Improve type support a bit more (#333)

Also fix default voice for util.py

Signed-off-by: rany <rany2@riseup.net>
This commit is contained in:
Rany
2024-11-23 19:59:39 +02:00
committed by GitHub
parent a3d468c7c9
commit 063957683c
7 changed files with 57 additions and 32 deletions

View File

@@ -20,6 +20,3 @@ warn_unreachable = True
strict_equality = True strict_equality = True
strict = True strict = True
[mypy-edge_tts.voices]
disallow_any_decorated = False

View File

@@ -11,7 +11,6 @@ from contextlib import nullcontext
from io import TextIOWrapper from io import TextIOWrapper
from queue import Queue from queue import Queue
from typing import ( from typing import (
Any,
AsyncGenerator, AsyncGenerator,
ContextManager, ContextManager,
Dict, Dict,
@@ -26,7 +25,8 @@ from xml.sax.saxutils import escape
import aiohttp import aiohttp
import certifi import certifi
from .constants import SEC_MS_GEC_VERSION, WSS_HEADERS, WSS_URL from .constants import DEFAULT_VOICE, SEC_MS_GEC_VERSION, WSS_HEADERS, WSS_URL
from .data_classes import TTSConfig
from .drm import DRM from .drm import DRM
from .exceptions import ( from .exceptions import (
NoAudioReceived, NoAudioReceived,
@@ -34,8 +34,7 @@ from .exceptions import (
UnknownResponse, UnknownResponse,
WebSocketError, WebSocketError,
) )
from .models import TTSConfig from .typing import CommunicateState, TTSChunk
from .typing import TTSChunk
def get_headers_and_data( def get_headers_and_data(
@@ -109,7 +108,7 @@ def split_text_by_byte_length(
text will be inside of an XML tag. text will be inside of an XML tag.
Args: Args:
text (str or bytes): The string to be split. text (str or bytes): The string to be split. If bytes, it must be UTF-8 encoded.
byte_length (int): The maximum byte length of each string in the list. byte_length (int): The maximum byte length of each string in the list.
Yield: Yield:
@@ -166,12 +165,9 @@ def mkssml(tc: TTSConfig, escaped_text: Union[str, bytes]) -> str:
Returns: Returns:
str: The SSML string. str: The SSML string.
""" """
# If the text is bytes, convert it to a string.
if isinstance(escaped_text, bytes): if isinstance(escaped_text, bytes):
escaped_text = escaped_text.decode("utf-8") escaped_text = escaped_text.decode("utf-8")
# Return the SSML string.
return ( return (
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>" "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
f"<voice name='{tc.voice}'>" f"<voice name='{tc.voice}'>"
@@ -244,7 +240,7 @@ class Communicate:
def __init__( def __init__(
self, self,
text: str, text: str,
voice: str = "en-US-EmmaMultilingualNeural", voice: str = DEFAULT_VOICE,
*, *,
rate: str = "+0%", rate: str = "+0%",
volume: str = "+0%", volume: str = "+0%",
@@ -290,8 +286,8 @@ class Communicate:
self.connector: Optional[aiohttp.BaseConnector] = connector self.connector: Optional[aiohttp.BaseConnector] = connector
# Store current state of TTS. # Store current state of TTS.
self.state: Dict[str, Any] = { self.state: CommunicateState = {
"partial_text": None, "partial_text": b"",
"offset_compensation": 0, "offset_compensation": 0,
"last_duration_offset": 0, "last_duration_offset": 0,
"stream_was_called": False, "stream_was_called": False,

View File

@@ -6,6 +6,8 @@ TRUSTED_CLIENT_TOKEN = "6A5AA1D4EAFF4E9FB37E23D68491D6F4"
WSS_URL = f"wss://{BASE_URL}/edge/v1?TrustedClientToken={TRUSTED_CLIENT_TOKEN}" WSS_URL = f"wss://{BASE_URL}/edge/v1?TrustedClientToken={TRUSTED_CLIENT_TOKEN}"
VOICE_LIST = f"https://{BASE_URL}/voices/list?trustedclienttoken={TRUSTED_CLIENT_TOKEN}" VOICE_LIST = f"https://{BASE_URL}/voices/list?trustedclienttoken={TRUSTED_CLIENT_TOKEN}"
DEFAULT_VOICE = "en-US-EmmaMultilingualNeural"
CHROMIUM_FULL_VERSION = "130.0.2849.68" CHROMIUM_FULL_VERSION = "130.0.2849.68"
CHROMIUM_MAJOR_VERSION = CHROMIUM_FULL_VERSION.split(".", maxsplit=1)[0] CHROMIUM_MAJOR_VERSION = CHROMIUM_FULL_VERSION.split(".", maxsplit=1)[0]
SEC_MS_GEC_VERSION = f"1-{CHROMIUM_FULL_VERSION}" SEC_MS_GEC_VERSION = f"1-{CHROMIUM_FULL_VERSION}"

View File

@@ -1,6 +1,8 @@
"""This module contains the TTSConfig dataclass, which represents the """Data models for edge-tts."""
internal TTS configuration for edge-tts's Communicate class."""
# pylint: disable=too-few-public-methods
import argparse
import re import re
from dataclasses import dataclass from dataclasses import dataclass
@@ -69,3 +71,18 @@ class TTSConfig:
self.validate_string_param("rate", self.rate, r"^[+-]\d+%$") self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
self.validate_string_param("volume", self.volume, r"^[+-]\d+%$") self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$") self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
class UtilArgs(argparse.Namespace):
"""CLI arguments."""
text: str
file: str
voice: str
list_voices: bool
rate: str
volume: str
pitch: str
write_media: str
write_subtitles: str
proxy: str

View File

@@ -78,15 +78,24 @@ class Voice(TypedDict):
VoiceTag: VoiceTag VoiceTag: VoiceTag
class VoiceManagerVoice(Voice): class VoicesManagerVoice(Voice):
"""Voice data for VoiceManager.""" """Voice data for VoicesManager."""
Language: str Language: str
class VoiceManagerFind(TypedDict): class VoicesManagerFind(TypedDict):
"""Voice data for VoiceManager.find().""" """Voice data for VoicesManager.find()."""
Gender: NotRequired[Literal["Female", "Male"]] Gender: NotRequired[Literal["Female", "Male"]]
Locale: NotRequired[str] Locale: NotRequired[str]
Language: NotRequired[str] Language: NotRequired[str]
class CommunicateState(TypedDict):
"""Communicate state data."""
partial_text: bytes
offset_compensation: float
last_duration_offset: float
stream_was_called: bool

View File

@@ -3,14 +3,16 @@
import argparse import argparse
import asyncio import asyncio
import sys import sys
from typing import Any, Optional, TextIO from typing import Optional, TextIO
from tabulate import tabulate from tabulate import tabulate
from . import Communicate, SubMaker, list_voices from . import Communicate, SubMaker, list_voices
from .constants import DEFAULT_VOICE
from .data_classes import UtilArgs
async def _print_voices(*, proxy: str) -> None: async def _print_voices(*, proxy: Optional[str]) -> None:
"""Print all available voices.""" """Print all available voices."""
voices = await list_voices(proxy=proxy) voices = await list_voices(proxy=proxy)
voices = sorted(voices, key=lambda voice: voice["ShortName"]) voices = sorted(voices, key=lambda voice: voice["ShortName"])
@@ -27,7 +29,7 @@ async def _print_voices(*, proxy: str) -> None:
print(tabulate(table, headers)) print(tabulate(table, headers))
async def _run_tts(args: Any) -> None: async def _run_tts(args: UtilArgs) -> None:
"""Run TTS after parsing arguments from command line.""" """Run TTS after parsing arguments from command line."""
try: try:
@@ -84,15 +86,17 @@ async def _run_tts(args: Any) -> None:
async def amain() -> None: async def amain() -> None:
"""Async main function""" """Async main function"""
parser = argparse.ArgumentParser(description="Microsoft Edge TTS") parser = argparse.ArgumentParser(
description="Text-to-speech using Microsoft Edge's online TTS service."
)
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-t", "--text", help="what TTS will say") group.add_argument("-t", "--text", help="what TTS will say")
group.add_argument("-f", "--file", help="same as --text but read from file") group.add_argument("-f", "--file", help="same as --text but read from file")
parser.add_argument( parser.add_argument(
"-v", "-v",
"--voice", "--voice",
help="voice for TTS. Default: en-US-AriaNeural", help=f"voice for TTS. Default: {DEFAULT_VOICE}",
default="en-US-AriaNeural", default=DEFAULT_VOICE,
) )
group.add_argument( group.add_argument(
"-l", "-l",
@@ -111,7 +115,7 @@ async def amain() -> None:
help="send subtitle output to provided file instead of stderr", help="send subtitle output to provided file instead of stderr",
) )
parser.add_argument("--proxy", help="use a proxy for TTS and voice list.") parser.add_argument("--proxy", help="use a proxy for TTS and voice list.")
args = parser.parse_args() args = parser.parse_args(namespace=UtilArgs())
if args.list_voices: if args.list_voices:
await _print_voices(proxy=args.proxy) await _print_voices(proxy=args.proxy)

View File

@@ -3,7 +3,7 @@ correct voice based on their attributes."""
import json import json
import ssl import ssl
from typing import Any, List, Optional from typing import List, Optional
import aiohttp import aiohttp
import certifi import certifi
@@ -11,7 +11,7 @@ from typing_extensions import Unpack
from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
from .drm import DRM from .drm import DRM
from .typing import Voice, VoiceManagerFind, VoiceManagerVoice from .typing import Voice, VoicesManagerFind, VoicesManagerVoice
async def __list_voices( async def __list_voices(
@@ -91,12 +91,12 @@ class VoicesManager:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.voices: List[VoiceManagerVoice] = [] self.voices: List[VoicesManagerVoice] = []
self.called_create: bool = False self.called_create: bool = False
@classmethod @classmethod
async def create( async def create(
cls: Any, custom_voices: Optional[List[Voice]] = None cls, custom_voices: Optional[List[Voice]] = None
) -> "VoicesManager": ) -> "VoicesManager":
""" """
Creates a VoicesManager object and populates it with all available voices. Creates a VoicesManager object and populates it with all available voices.
@@ -109,7 +109,7 @@ class VoicesManager:
self.called_create = True self.called_create = True
return self return self
def find(self, **kwargs: Unpack[VoiceManagerFind]) -> List[VoiceManagerVoice]: def find(self, **kwargs: Unpack[VoicesManagerFind]) -> List[VoicesManagerVoice]:
""" """
Finds all matching voices based on the provided attributes. Finds all matching voices based on the provided attributes.
""" """