Use TypedDict for voice list and TTS stream (#320)
Signed-off-by: rany <rany2@riseup.net>
This commit is contained in:
1
setup.py
1
setup.py
@@ -6,5 +6,6 @@ setup(
|
||||
install_requires=[
|
||||
"aiohttp>=3.8.0",
|
||||
"certifi>=2023.11.17",
|
||||
"typing-extensions>=4.1.0",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ from .exceptions import (
|
||||
WebSocketError,
|
||||
)
|
||||
from .models import TTSConfig
|
||||
from .typing import TTSChunk
|
||||
|
||||
|
||||
def get_headers_and_data(
|
||||
@@ -297,7 +298,7 @@ class Communicate:
|
||||
"stream_was_called": False,
|
||||
}
|
||||
|
||||
def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
|
||||
def __parse_metadata(self, data: bytes) -> TTSChunk:
|
||||
for meta_obj in json.loads(data)["Metadata"]:
|
||||
meta_type = meta_obj["Type"]
|
||||
if meta_type == "WordBoundary":
|
||||
@@ -316,7 +317,7 @@ class Communicate:
|
||||
raise UnknownResponse(f"Unknown metadata type: {meta_type}")
|
||||
raise UnexpectedResponse("No WordBoundary metadata found")
|
||||
|
||||
async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
async def __stream(self) -> AsyncGenerator[TTSChunk, None]:
|
||||
async def send_command_request() -> None:
|
||||
"""Sends the request to the service."""
|
||||
|
||||
@@ -479,7 +480,7 @@ class Communicate:
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
) -> AsyncGenerator[TTSChunk, None]:
|
||||
"""
|
||||
Streams audio and metadata from the service.
|
||||
|
||||
|
||||
84
src/edge_tts/typing.py
Normal file
84
src/edge_tts/typing.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Custom types for edge-tts."""
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
|
||||
from typing import List
|
||||
|
||||
from typing_extensions import Literal, NotRequired, TypedDict
|
||||
|
||||
|
||||
class TTSChunk(TypedDict):
|
||||
"""TTS chunk data."""
|
||||
|
||||
type: Literal["audio", "WordBoundary"]
|
||||
data: NotRequired[bytes] # only for audio
|
||||
duration: NotRequired[float] # only for WordBoundary
|
||||
offset: NotRequired[float] # only for WordBoundary
|
||||
text: NotRequired[str] # only for WordBoundary
|
||||
|
||||
|
||||
class VoiceTag(TypedDict):
|
||||
"""VoiceTag data."""
|
||||
|
||||
ContentCategories: List[
|
||||
Literal[
|
||||
"Cartoon",
|
||||
"Conversation",
|
||||
"Copilot",
|
||||
"Dialect",
|
||||
"General",
|
||||
"News",
|
||||
"Novel",
|
||||
"Sports",
|
||||
]
|
||||
]
|
||||
VoicePersonalities: List[
|
||||
Literal[
|
||||
"Approachable",
|
||||
"Authentic",
|
||||
"Authority",
|
||||
"Bright",
|
||||
"Caring",
|
||||
"Casual",
|
||||
"Cheerful",
|
||||
"Clear",
|
||||
"Comfort",
|
||||
"Confident",
|
||||
"Considerate",
|
||||
"Conversational",
|
||||
"Cute",
|
||||
"Expressive",
|
||||
"Friendly",
|
||||
"Honest",
|
||||
"Humorous",
|
||||
"Lively",
|
||||
"Passion",
|
||||
"Pleasant",
|
||||
"Positive",
|
||||
"Professional",
|
||||
"Rational",
|
||||
"Reliable",
|
||||
"Sincere",
|
||||
"Sunshine",
|
||||
"Warm",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class Voice(TypedDict):
|
||||
"""Voice data."""
|
||||
|
||||
Name: str
|
||||
ShortName: str
|
||||
Gender: Literal["Female", "Male"]
|
||||
Locale: str
|
||||
SuggestedCodec: Literal["audio-24khz-48kbitrate-mono-mp3"]
|
||||
FriendlyName: str
|
||||
Status: Literal["GA"]
|
||||
VoiceTag: VoiceTag
|
||||
|
||||
|
||||
class VoiceManagerVoice(Voice):
|
||||
"""Voice data for VoiceManager."""
|
||||
|
||||
Language: str
|
||||
@@ -3,18 +3,19 @@ correct voice based on their attributes."""
|
||||
|
||||
import json
|
||||
import ssl
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
|
||||
from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
|
||||
from .drm import DRM
|
||||
from .typing import Voice, VoiceManagerVoice
|
||||
|
||||
|
||||
async def __list_voices(
|
||||
session: aiohttp.ClientSession, ssl_ctx: ssl.SSLContext, proxy: Optional[str]
|
||||
) -> Any:
|
||||
) -> List[Voice]:
|
||||
"""
|
||||
Private function that makes the request to the voice list URL and parses the
|
||||
JSON response. This function is used by list_voices() and makes it easier to
|
||||
@@ -26,7 +27,7 @@ async def __list_voices(
|
||||
proxy (Optional[str]): The proxy to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of voice attributes.
|
||||
List[Voice]: A list of voices and their attributes.
|
||||
"""
|
||||
async with session.get(
|
||||
f"{VOICE_LIST}&Sec-MS-GEC={DRM.generate_sec_ms_gec()}"
|
||||
@@ -36,11 +37,25 @@ async def __list_voices(
|
||||
ssl=ssl_ctx,
|
||||
raise_for_status=True,
|
||||
) as url:
|
||||
data = json.loads(await url.text())
|
||||
data: List[Voice] = json.loads(await url.text())
|
||||
|
||||
for voice in data:
|
||||
# Remove leading and trailing whitespace from categories and personalities.
|
||||
# This has only happened in one case with the zh-CN-YunjianNeural voice
|
||||
# where there was a leading space in one of the categories.
|
||||
voice["VoiceTag"]["ContentCategories"] = [
|
||||
category.strip() # type: ignore
|
||||
for category in voice["VoiceTag"]["ContentCategories"]
|
||||
]
|
||||
voice["VoiceTag"]["VoicePersonalities"] = [
|
||||
personality.strip() # type: ignore
|
||||
for personality in voice["VoiceTag"]["VoicePersonalities"]
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def list_voices(*, proxy: Optional[str] = None) -> Any:
|
||||
async def list_voices(*, proxy: Optional[str] = None) -> List[Voice]:
|
||||
"""
|
||||
List all available voices and their attributes.
|
||||
|
||||
@@ -51,7 +66,7 @@ async def list_voices(*, proxy: Optional[str] = None) -> Any:
|
||||
proxy (Optional[str]): The proxy to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of voice attributes.
|
||||
List[Voice]: A list of voices and their attributes.
|
||||
"""
|
||||
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
@@ -72,26 +87,23 @@ class VoicesManager:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.voices: List[Dict[str, Any]] = []
|
||||
self.voices: List[VoiceManagerVoice] = []
|
||||
self.called_create: bool = False
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls: Any, custom_voices: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Any:
|
||||
async def create(cls: Any, custom_voices: Optional[List[Voice]] = None) -> Any:
|
||||
"""
|
||||
Creates a VoicesManager object and populates it with all available voices.
|
||||
"""
|
||||
self = VoicesManager()
|
||||
self.voices = await list_voices() if custom_voices is None else custom_voices
|
||||
voices = await list_voices() if custom_voices is None else custom_voices
|
||||
self.voices = [
|
||||
{**voice, **{"Language": voice["Locale"].split("-")[0]}}
|
||||
for voice in self.voices
|
||||
{**voice, "Language": voice["Locale"].split("-")[0]} for voice in voices
|
||||
]
|
||||
self.called_create = True
|
||||
return self
|
||||
|
||||
def find(self, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||
def find(self, **kwargs: Any) -> List[VoiceManagerVoice]:
|
||||
"""
|
||||
Finds all matching voices based on the provided attributes.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user