Use TypedDict for voice list and TTS stream (#320)
Signed-off-by: rany <rany2@riseup.net>
This commit is contained in:
1
setup.py
1
setup.py
@@ -6,5 +6,6 @@ setup(
     install_requires=[
         "aiohttp>=3.8.0",
         "certifi>=2023.11.17",
+        "typing-extensions>=4.1.0",
     ],
 )
|
|||||||
@@ -35,6 +35,7 @@ from .exceptions import (
     WebSocketError,
 )
 from .models import TTSConfig
+from .typing import TTSChunk
 
 
 def get_headers_and_data(
@@ -297,7 +298,7 @@ class Communicate:
             "stream_was_called": False,
         }
 
-    def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
+    def __parse_metadata(self, data: bytes) -> TTSChunk:
         for meta_obj in json.loads(data)["Metadata"]:
             meta_type = meta_obj["Type"]
             if meta_type == "WordBoundary":
@@ -316,7 +317,7 @@ class Communicate:
             raise UnknownResponse(f"Unknown metadata type: {meta_type}")
         raise UnexpectedResponse("No WordBoundary metadata found")
 
-    async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
+    async def __stream(self) -> AsyncGenerator[TTSChunk, None]:
         async def send_command_request() -> None:
             """Sends the request to the service."""
 
@@ -479,7 +480,7 @@ class Communicate:
 
     async def stream(
         self,
-    ) -> AsyncGenerator[Dict[str, Any], None]:
+    ) -> AsyncGenerator[TTSChunk, None]:
         """
         Streams audio and metadata from the service.
 
|
|||||||
84
src/edge_tts/typing.py
Normal file
84
src/edge_tts/typing.py
Normal file
@@ -0,0 +1,84 @@
+"""Custom types for edge-tts."""
+
+# pylint: disable=too-few-public-methods
+
+from typing import List
+
+from typing_extensions import Literal, NotRequired, TypedDict
+
+
+class TTSChunk(TypedDict):
+    """TTS chunk data."""
+
+    type: Literal["audio", "WordBoundary"]
+    data: NotRequired[bytes] # only for audio
+    duration: NotRequired[float] # only for WordBoundary
+    offset: NotRequired[float] # only for WordBoundary
+    text: NotRequired[str] # only for WordBoundary
+
+
+class VoiceTag(TypedDict):
+    """VoiceTag data."""
+
+    ContentCategories: List[
+        Literal[
+            "Cartoon",
+            "Conversation",
+            "Copilot",
+            "Dialect",
+            "General",
+            "News",
+            "Novel",
+            "Sports",
+        ]
+    ]
+    VoicePersonalities: List[
+        Literal[
+            "Approachable",
+            "Authentic",
+            "Authority",
+            "Bright",
+            "Caring",
+            "Casual",
+            "Cheerful",
+            "Clear",
+            "Comfort",
+            "Confident",
+            "Considerate",
+            "Conversational",
+            "Cute",
+            "Expressive",
+            "Friendly",
+            "Honest",
+            "Humorous",
+            "Lively",
+            "Passion",
+            "Pleasant",
+            "Positive",
+            "Professional",
+            "Rational",
+            "Reliable",
+            "Sincere",
+            "Sunshine",
+            "Warm",
+        ]
+    ]
+
+
+class Voice(TypedDict):
+    """Voice data."""
+
+    Name: str
+    ShortName: str
+    Gender: Literal["Female", "Male"]
+    Locale: str
+    SuggestedCodec: Literal["audio-24khz-48kbitrate-mono-mp3"]
+    FriendlyName: str
+    Status: Literal["GA"]
+    VoiceTag: VoiceTag
+
+
+class VoiceManagerVoice(Voice):
+    """Voice data for VoiceManager."""
+
+    Language: str
@@ -3,18 +3,19 @@ correct voice based on their attributes."""
 
 import json
 import ssl
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 
 import aiohttp
 import certifi
 
 from .constants import SEC_MS_GEC_VERSION, VOICE_HEADERS, VOICE_LIST
 from .drm import DRM
+from .typing import Voice, VoiceManagerVoice
 
 
 async def __list_voices(
     session: aiohttp.ClientSession, ssl_ctx: ssl.SSLContext, proxy: Optional[str]
-) -> Any:
+) -> List[Voice]:
     """
     Private function that makes the request to the voice list URL and parses the
     JSON response. This function is used by list_voices() and makes it easier to
@@ -26,7 +27,7 @@ async def __list_voices(
         proxy (Optional[str]): The proxy to use for the request.
 
     Returns:
-        dict: A dictionary of voice attributes.
+        List[Voice]: A list of voices and their attributes.
     """
     async with session.get(
         f"{VOICE_LIST}&Sec-MS-GEC={DRM.generate_sec_ms_gec()}"
@@ -36,11 +37,25 @@ async def __list_voices(
         ssl=ssl_ctx,
         raise_for_status=True,
     ) as url:
-        data = json.loads(await url.text())
+        data: List[Voice] = json.loads(await url.text())
+
+    for voice in data:
+        # Remove leading and trailing whitespace from categories and personalities.
+        # This has only happened in one case with the zh-CN-YunjianNeural voice
+        # where there was a leading space in one of the categories.
+        voice["VoiceTag"]["ContentCategories"] = [
+            category.strip() # type: ignore
+            for category in voice["VoiceTag"]["ContentCategories"]
+        ]
+        voice["VoiceTag"]["VoicePersonalities"] = [
+            personality.strip() # type: ignore
+            for personality in voice["VoiceTag"]["VoicePersonalities"]
+        ]
+
     return data
 
 
-async def list_voices(*, proxy: Optional[str] = None) -> Any:
+async def list_voices(*, proxy: Optional[str] = None) -> List[Voice]:
     """
     List all available voices and their attributes.
 
@@ -51,7 +66,7 @@ async def list_voices(*, proxy: Optional[str] = None) -> Any:
         proxy (Optional[str]): The proxy to use for the request.
 
     Returns:
-        dict: A dictionary of voice attributes.
+        List[Voice]: A list of voices and their attributes.
     """
     ssl_ctx = ssl.create_default_context(cafile=certifi.where())
     async with aiohttp.ClientSession(trust_env=True) as session:
@@ -72,26 +87,23 @@ class VoicesManager:
     """
 
     def __init__(self) -> None:
-        self.voices: List[Dict[str, Any]] = []
+        self.voices: List[VoiceManagerVoice] = []
         self.called_create: bool = False
 
     @classmethod
-    async def create(
-        cls: Any, custom_voices: Optional[List[Dict[str, Any]]] = None
-    ) -> Any:
+    async def create(cls: Any, custom_voices: Optional[List[Voice]] = None) -> Any:
         """
         Creates a VoicesManager object and populates it with all available voices.
         """
         self = VoicesManager()
-        self.voices = await list_voices() if custom_voices is None else custom_voices
+        voices = await list_voices() if custom_voices is None else custom_voices
         self.voices = [
-            {**voice, **{"Language": voice["Locale"].split("-")[0]}}
-            for voice in self.voices
+            {**voice, "Language": voice["Locale"].split("-")[0]} for voice in voices
         ]
         self.called_create = True
         return self
 
-    def find(self, **kwargs: Any) -> List[Dict[str, Any]]:
+    def find(self, **kwargs: Any) -> List[VoiceManagerVoice]:
         """
         Finds all matching voices based on the provided attributes.
         """
|
|||||||
Reference in New Issue
Block a user