Use TypedDict for voice list and TTS stream (#320)

Signed-off-by: rany <rany2@riseup.net>
This commit is contained in:
Rany
2024-11-22 20:04:52 +02:00
committed by GitHub
parent 48c7f3ad2e
commit 3e4de19344
4 changed files with 115 additions and 17 deletions

View File

@@ -6,5 +6,6 @@ setup(
install_requires=[
"aiohttp>=3.8.0",
"certifi>=2023.11.17",
"typing-extensions>=4.1.0",
],
)

View File

@@ -35,6 +35,7 @@ from .exceptions import (
WebSocketError,
)
from .models import TTSConfig
from .typing import TTSChunk
def get_headers_and_data(
@@ -297,7 +298,7 @@ class Communicate:
"stream_was_called": False,
}
def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
def __parse_metadata(self, data: bytes) -> TTSChunk:
for meta_obj in json.loads(data)["Metadata"]:
meta_type = meta_obj["Type"]
if meta_type == "WordBoundary":
@@ -316,7 +317,7 @@ class Communicate:
raise UnknownResponse(f"Unknown metadata type: {meta_type}")
raise UnexpectedResponse("No WordBoundary metadata found")
async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
async def __stream(self) -> AsyncGenerator[TTSChunk, None]:
async def send_command_request() -> None:
"""Sends the request to the service."""
@@ -479,7 +480,7 @@ class Communicate:
async def stream(
self,
) -> AsyncGenerator[Dict[str, Any], None]:
) -> AsyncGenerator[TTSChunk, None]:
"""
Streams audio and metadata from the service.

84
src/edge_tts/typing.py Normal file
View File

@@ -0,0 +1,84 @@
"""Custom types for edge-tts."""
# pylint: disable=too-few-public-methods
from typing import List
from typing_extensions import Literal, NotRequired, TypedDict
class TTSChunk(TypedDict):
"""TTS chunk data."""
type: Literal["audio", "WordBoundary"]
data: NotRequired[bytes] # only for audio
duration: NotRequired[float] # only for WordBoundary
offset: NotRequired[float] # only for WordBoundary
text: NotRequired[str] # only for WordBoundary
class VoiceTag(TypedDict):
"""VoiceTag data."""
ContentCategories: List[
Literal[
"Cartoon",
"Conversation",
"Copilot",
"Dialect",
"General",
"News",
"Novel",
"Sports",
]
]
VoicePersonalities: List[
Literal[
"Approachable",
"Authentic",
"Authority",
"Bright",
"Caring",
"Casual",
"Cheerful",
"Clear",
"Comfort",
"Confident",
"Considerate",
"Conversational",
"Cute",
"Expressive",
"Friendly",
"Honest",
"Humorous",
"Lively",
"Passion",
"Pleasant",
"Positive",
"Professional",
"Rational",
"Reliable",
"Sincere",
"Sunshine",
"Warm",
]
]
class Voice(TypedDict):
"""Voice data."""
Name: str
ShortName: str
Gender: Literal["Female", "Male"]
Locale: str
SuggestedCodec: Literal["audio-24khz-48kbitrate-mono-mp3"]
FriendlyName: str
Status: Literal["GA"]
VoiceTag: VoiceTag
class VoiceManagerVoice(Voice):
"""Voice data for VoiceManager."""
Language: str

View File

@@ -3,18 +3,19 @@ correct voice based on their attributes."""
import json
import ssl
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional
import aiohttp
import certifi
from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
from .drm import DRM
from .typing import Voice, VoiceManagerVoice
async def __list_voices(
session: aiohttp.ClientSession, ssl_ctx: ssl.SSLContext, proxy: Optional[str]
) -> Any:
) -> List[Voice]:
"""
Private function that makes the request to the voice list URL and parses the
JSON response. This function is used by list_voices() and makes it easier to
@@ -26,7 +27,7 @@ async def __list_voices(
proxy (Optional[str]): The proxy to use for the request.
Returns:
dict: A dictionary of voice attributes.
List[Voice]: A list of voices and their attributes.
"""
async with session.get(
f"{VOICE_LIST}&Sec-MS-GEC={DRM.generate_sec_ms_gec()}"
@@ -36,11 +37,25 @@ async def __list_voices(
ssl=ssl_ctx,
raise_for_status=True,
) as url:
data = json.loads(await url.text())
data: List[Voice] = json.loads(await url.text())
for voice in data:
# Remove leading and trailing whitespace from categories and personalities.
# This has only happened in one case with the zh-CN-YunjianNeural voice
# where there was a leading space in one of the categories.
voice["VoiceTag"]["ContentCategories"] = [
category.strip() # type: ignore
for category in voice["VoiceTag"]["ContentCategories"]
]
voice["VoiceTag"]["VoicePersonalities"] = [
personality.strip() # type: ignore
for personality in voice["VoiceTag"]["VoicePersonalities"]
]
return data
async def list_voices(*, proxy: Optional[str] = None) -> Any:
async def list_voices(*, proxy: Optional[str] = None) -> List[Voice]:
"""
List all available voices and their attributes.
@@ -51,7 +66,7 @@ async def list_voices(*, proxy: Optional[str] = None) -> Any:
proxy (Optional[str]): The proxy to use for the request.
Returns:
dict: A dictionary of voice attributes.
List[Voice]: A list of voices and their attributes.
"""
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
async with aiohttp.ClientSession(trust_env=True) as session:
@@ -72,26 +87,23 @@ class VoicesManager:
"""
def __init__(self) -> None:
self.voices: List[Dict[str, Any]] = []
self.voices: List[VoiceManagerVoice] = []
self.called_create: bool = False
@classmethod
async def create(
cls: Any, custom_voices: Optional[List[Dict[str, Any]]] = None
) -> Any:
async def create(cls: Any, custom_voices: Optional[List[Voice]] = None) -> Any:
"""
Creates a VoicesManager object and populates it with all available voices.
"""
self = VoicesManager()
self.voices = await list_voices() if custom_voices is None else custom_voices
voices = await list_voices() if custom_voices is None else custom_voices
self.voices = [
{**voice, **{"Language": voice["Locale"].split("-")[0]}}
for voice in self.voices
{**voice, "Language": voice["Locale"].split("-")[0]} for voice in voices
]
self.called_create = True
return self
def find(self, **kwargs: Any) -> List[Dict[str, Any]]:
def find(self, **kwargs: Any) -> List[VoiceManagerVoice]:
"""
Finds all matching voices based on the provided attributes.
"""