Improve type support a bit more (#333)
Also fix default voice for util.py Signed-off-by: rany <rany2@riseup.net>
This commit is contained in:
3
mypy.ini
3
mypy.ini
@@ -20,6 +20,3 @@ warn_unreachable = True
|
|||||||
|
|
||||||
strict_equality = True
|
strict_equality = True
|
||||||
strict = True
|
strict = True
|
||||||
|
|
||||||
[mypy-edge_tts.voices]
|
|
||||||
disallow_any_decorated = False
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from contextlib import nullcontext
|
|||||||
from io import TextIOWrapper
|
from io import TextIOWrapper
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
Dict,
|
Dict,
|
||||||
@@ -26,7 +25,8 @@ from xml.sax.saxutils import escape
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import certifi
|
import certifi
|
||||||
|
|
||||||
from .constants import SEC_MS_GEC_VERSION, WSS_HEADERS, WSS_URL
|
from .constants import DEFAULT_VOICE, SEC_MS_GEC_VERSION, WSS_HEADERS, WSS_URL
|
||||||
|
from .data_classes import TTSConfig
|
||||||
from .drm import DRM
|
from .drm import DRM
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
NoAudioReceived,
|
NoAudioReceived,
|
||||||
@@ -34,8 +34,7 @@ from .exceptions import (
|
|||||||
UnknownResponse,
|
UnknownResponse,
|
||||||
WebSocketError,
|
WebSocketError,
|
||||||
)
|
)
|
||||||
from .models import TTSConfig
|
from .typing import CommunicateState, TTSChunk
|
||||||
from .typing import TTSChunk
|
|
||||||
|
|
||||||
|
|
||||||
def get_headers_and_data(
|
def get_headers_and_data(
|
||||||
@@ -109,7 +108,7 @@ def split_text_by_byte_length(
|
|||||||
text will be inside of an XML tag.
|
text will be inside of an XML tag.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str or bytes): The string to be split.
|
text (str or bytes): The string to be split. If bytes, it must be UTF-8 encoded.
|
||||||
byte_length (int): The maximum byte length of each string in the list.
|
byte_length (int): The maximum byte length of each string in the list.
|
||||||
|
|
||||||
Yield:
|
Yield:
|
||||||
@@ -166,12 +165,9 @@ def mkssml(tc: TTSConfig, escaped_text: Union[str, bytes]) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
str: The SSML string.
|
str: The SSML string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If the text is bytes, convert it to a string.
|
|
||||||
if isinstance(escaped_text, bytes):
|
if isinstance(escaped_text, bytes):
|
||||||
escaped_text = escaped_text.decode("utf-8")
|
escaped_text = escaped_text.decode("utf-8")
|
||||||
|
|
||||||
# Return the SSML string.
|
|
||||||
return (
|
return (
|
||||||
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
|
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
|
||||||
f"<voice name='{tc.voice}'>"
|
f"<voice name='{tc.voice}'>"
|
||||||
@@ -244,7 +240,7 @@ class Communicate:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
voice: str = "en-US-EmmaMultilingualNeural",
|
voice: str = DEFAULT_VOICE,
|
||||||
*,
|
*,
|
||||||
rate: str = "+0%",
|
rate: str = "+0%",
|
||||||
volume: str = "+0%",
|
volume: str = "+0%",
|
||||||
@@ -290,8 +286,8 @@ class Communicate:
|
|||||||
self.connector: Optional[aiohttp.BaseConnector] = connector
|
self.connector: Optional[aiohttp.BaseConnector] = connector
|
||||||
|
|
||||||
# Store current state of TTS.
|
# Store current state of TTS.
|
||||||
self.state: Dict[str, Any] = {
|
self.state: CommunicateState = {
|
||||||
"partial_text": None,
|
"partial_text": b"",
|
||||||
"offset_compensation": 0,
|
"offset_compensation": 0,
|
||||||
"last_duration_offset": 0,
|
"last_duration_offset": 0,
|
||||||
"stream_was_called": False,
|
"stream_was_called": False,
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ TRUSTED_CLIENT_TOKEN = "6A5AA1D4EAFF4E9FB37E23D68491D6F4"
|
|||||||
WSS_URL = f"wss://{BASE_URL}/edge/v1?TrustedClientToken={TRUSTED_CLIENT_TOKEN}"
|
WSS_URL = f"wss://{BASE_URL}/edge/v1?TrustedClientToken={TRUSTED_CLIENT_TOKEN}"
|
||||||
VOICE_LIST = f"https://{BASE_URL}/voices/list?trustedclienttoken={TRUSTED_CLIENT_TOKEN}"
|
VOICE_LIST = f"https://{BASE_URL}/voices/list?trustedclienttoken={TRUSTED_CLIENT_TOKEN}"
|
||||||
|
|
||||||
|
DEFAULT_VOICE = "en-US-EmmaMultilingualNeural"
|
||||||
|
|
||||||
CHROMIUM_FULL_VERSION = "130.0.2849.68"
|
CHROMIUM_FULL_VERSION = "130.0.2849.68"
|
||||||
CHROMIUM_MAJOR_VERSION = CHROMIUM_FULL_VERSION.split(".", maxsplit=1)[0]
|
CHROMIUM_MAJOR_VERSION = CHROMIUM_FULL_VERSION.split(".", maxsplit=1)[0]
|
||||||
SEC_MS_GEC_VERSION = f"1-{CHROMIUM_FULL_VERSION}"
|
SEC_MS_GEC_VERSION = f"1-{CHROMIUM_FULL_VERSION}"
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""This module contains the TTSConfig dataclass, which represents the
|
"""Data models for edge-tts."""
|
||||||
internal TTS configuration for edge-tts's Communicate class."""
|
|
||||||
|
|
||||||
|
# pylint: disable=too-few-public-methods
|
||||||
|
|
||||||
|
import argparse
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -69,3 +71,18 @@ class TTSConfig:
|
|||||||
self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
|
self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
|
||||||
self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
|
self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
|
||||||
self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
|
self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
|
||||||
|
|
||||||
|
|
||||||
|
class UtilArgs(argparse.Namespace):
|
||||||
|
"""CLI arguments."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
file: str
|
||||||
|
voice: str
|
||||||
|
list_voices: bool
|
||||||
|
rate: str
|
||||||
|
volume: str
|
||||||
|
pitch: str
|
||||||
|
write_media: str
|
||||||
|
write_subtitles: str
|
||||||
|
proxy: str
|
||||||
@@ -78,15 +78,24 @@ class Voice(TypedDict):
|
|||||||
VoiceTag: VoiceTag
|
VoiceTag: VoiceTag
|
||||||
|
|
||||||
|
|
||||||
class VoiceManagerVoice(Voice):
|
class VoicesManagerVoice(Voice):
|
||||||
"""Voice data for VoiceManager."""
|
"""Voice data for VoicesManager."""
|
||||||
|
|
||||||
Language: str
|
Language: str
|
||||||
|
|
||||||
|
|
||||||
class VoiceManagerFind(TypedDict):
|
class VoicesManagerFind(TypedDict):
|
||||||
"""Voice data for VoiceManager.find()."""
|
"""Voice data for VoicesManager.find()."""
|
||||||
|
|
||||||
Gender: NotRequired[Literal["Female", "Male"]]
|
Gender: NotRequired[Literal["Female", "Male"]]
|
||||||
Locale: NotRequired[str]
|
Locale: NotRequired[str]
|
||||||
Language: NotRequired[str]
|
Language: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class CommunicateState(TypedDict):
|
||||||
|
"""Communicate state data."""
|
||||||
|
|
||||||
|
partial_text: bytes
|
||||||
|
offset_compensation: float
|
||||||
|
last_duration_offset: float
|
||||||
|
stream_was_called: bool
|
||||||
|
|||||||
@@ -3,14 +3,16 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Optional, TextIO
|
from typing import Optional, TextIO
|
||||||
|
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from . import Communicate, SubMaker, list_voices
|
from . import Communicate, SubMaker, list_voices
|
||||||
|
from .constants import DEFAULT_VOICE
|
||||||
|
from .data_classes import UtilArgs
|
||||||
|
|
||||||
|
|
||||||
async def _print_voices(*, proxy: str) -> None:
|
async def _print_voices(*, proxy: Optional[str]) -> None:
|
||||||
"""Print all available voices."""
|
"""Print all available voices."""
|
||||||
voices = await list_voices(proxy=proxy)
|
voices = await list_voices(proxy=proxy)
|
||||||
voices = sorted(voices, key=lambda voice: voice["ShortName"])
|
voices = sorted(voices, key=lambda voice: voice["ShortName"])
|
||||||
@@ -27,7 +29,7 @@ async def _print_voices(*, proxy: str) -> None:
|
|||||||
print(tabulate(table, headers))
|
print(tabulate(table, headers))
|
||||||
|
|
||||||
|
|
||||||
async def _run_tts(args: Any) -> None:
|
async def _run_tts(args: UtilArgs) -> None:
|
||||||
"""Run TTS after parsing arguments from command line."""
|
"""Run TTS after parsing arguments from command line."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -84,15 +86,17 @@ async def _run_tts(args: Any) -> None:
|
|||||||
|
|
||||||
async def amain() -> None:
|
async def amain() -> None:
|
||||||
"""Async main function"""
|
"""Async main function"""
|
||||||
parser = argparse.ArgumentParser(description="Microsoft Edge TTS")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Text-to-speech using Microsoft Edge's online TTS service."
|
||||||
|
)
|
||||||
group = parser.add_mutually_exclusive_group(required=True)
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
group.add_argument("-t", "--text", help="what TTS will say")
|
group.add_argument("-t", "--text", help="what TTS will say")
|
||||||
group.add_argument("-f", "--file", help="same as --text but read from file")
|
group.add_argument("-f", "--file", help="same as --text but read from file")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-v",
|
"-v",
|
||||||
"--voice",
|
"--voice",
|
||||||
help="voice for TTS. Default: en-US-AriaNeural",
|
help=f"voice for TTS. Default: {DEFAULT_VOICE}",
|
||||||
default="en-US-AriaNeural",
|
default=DEFAULT_VOICE,
|
||||||
)
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"-l",
|
"-l",
|
||||||
@@ -111,7 +115,7 @@ async def amain() -> None:
|
|||||||
help="send subtitle output to provided file instead of stderr",
|
help="send subtitle output to provided file instead of stderr",
|
||||||
)
|
)
|
||||||
parser.add_argument("--proxy", help="use a proxy for TTS and voice list.")
|
parser.add_argument("--proxy", help="use a proxy for TTS and voice list.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args(namespace=UtilArgs())
|
||||||
|
|
||||||
if args.list_voices:
|
if args.list_voices:
|
||||||
await _print_voices(proxy=args.proxy)
|
await _print_voices(proxy=args.proxy)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ correct voice based on their attributes."""
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import ssl
|
import ssl
|
||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import certifi
|
import certifi
|
||||||
@@ -11,7 +11,7 @@ from typing_extensions import Unpack
|
|||||||
|
|
||||||
from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
|
from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
|
||||||
from .drm import DRM
|
from .drm import DRM
|
||||||
from .typing import Voice, VoiceManagerFind, VoiceManagerVoice
|
from .typing import Voice, VoicesManagerFind, VoicesManagerVoice
|
||||||
|
|
||||||
|
|
||||||
async def __list_voices(
|
async def __list_voices(
|
||||||
@@ -91,12 +91,12 @@ class VoicesManager:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.voices: List[VoiceManagerVoice] = []
|
self.voices: List[VoicesManagerVoice] = []
|
||||||
self.called_create: bool = False
|
self.called_create: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls: Any, custom_voices: Optional[List[Voice]] = None
|
cls, custom_voices: Optional[List[Voice]] = None
|
||||||
) -> "VoicesManager":
|
) -> "VoicesManager":
|
||||||
"""
|
"""
|
||||||
Creates a VoicesManager object and populates it with all available voices.
|
Creates a VoicesManager object and populates it with all available voices.
|
||||||
@@ -109,7 +109,7 @@ class VoicesManager:
|
|||||||
self.called_create = True
|
self.called_create = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def find(self, **kwargs: Unpack[VoiceManagerFind]) -> List[VoiceManagerVoice]:
|
def find(self, **kwargs: Unpack[VoicesManagerFind]) -> List[VoicesManagerVoice]:
|
||||||
"""
|
"""
|
||||||
Finds all matching voices based on the provided attributes.
|
Finds all matching voices based on the provided attributes.
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user