@@ -5,7 +5,6 @@ Communicate package.
 import asyncio
 import concurrent.futures
 import json
-import re
 import ssl
 import time
 import uuid
@@ -28,37 +27,38 @@ from xml.sax.saxutils import escape
 import aiohttp
 import certifi

-from edge_tts.exceptions import (
+from .constants import WSS_URL
+from .exceptions import (
     NoAudioReceived,
     UnexpectedResponse,
     UnknownResponse,
     WebSocketError,
 )
-
-from .constants import WSS_URL
+from .models import TTSConfig


-def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
+def get_headers_and_data(
+    data: bytes, header_length: int
+) -> Tuple[Dict[bytes, bytes], bytes]:
     """
     Returns the headers and data from the given data.

     Args:
-        data (str or bytes): The data to be parsed.
+        data (bytes): The data to be parsed.
+        header_length (int): The length of the header.

     Returns:
         tuple: The headers and data to be used in the request.
     """
-    if isinstance(data, str):
-        data = data.encode("utf-8")
     if not isinstance(data, bytes):
-        raise TypeError("data must be str or bytes")
+        raise TypeError("data must be bytes")

     headers = {}
-    for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
+    for line in data[:header_length].split(b"\r\n"):
         key, value = line.split(b":", 1)
         headers[key] = value

-    return headers, data[data.find(b"\r\n\r\n") + 4 :]
+    return headers, data[header_length + 2 :]


 def remove_incompatible_characters(string: Union[str, bytes]) -> str:
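
Reviewer note (not part of the diff): get_headers_and_data() no longer looks for the b"\r\n\r\n" separator itself; each call site now passes the header length explicitly. A quick sanity check of the text-frame calling convention, using a made-up frame:

    frame = 'Path:audio.metadata\r\nContent-Type:application/json\r\n\r\n{"Metadata": []}'
    encoded = frame.encode("utf-8")
    headers, body = get_headers_and_data(encoded, encoded.find(b"\r\n\r\n"))
    assert headers[b"Path"] == b"audio.metadata"
    # Only header_length + 2 bytes are skipped, so the body keeps a leading b"\r\n";
    # json.loads() tolerates that leading whitespace, so the metadata parser is unaffected.
    assert body == b'\r\n{"Metadata": []}'
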
@@ -154,24 +154,32 @@ def split_text_by_byte_length(
         yield new_text


-def mkssml(
-    text: Union[str, bytes], voice: str, rate: str, volume: str, pitch: str
-) -> str:
+def mkssml(tc: TTSConfig, escaped_text: Union[str, bytes]) -> str:
     """
     Creates a SSML string from the given parameters.

+    Args:
+        tc (TTSConfig): The TTS configuration.
+        escaped_text (str or bytes): The escaped text. If bytes, it must be UTF-8 encoded.
+
     Returns:
         str: The SSML string.
     """
-    if isinstance(text, bytes):
-        text = text.decode("utf-8")
+    # If the text is bytes, convert it to a string.
+    if isinstance(escaped_text, bytes):
+        escaped_text = escaped_text.decode("utf-8")

-    ssml = (
+    # Return the SSML string.
+    return (
         "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
-        f"<voice name='{voice}'><prosody pitch='{pitch}' rate='{rate}' volume='{volume}'>"
-        f"{text}</prosody></voice></speak>"
+        f"<voice name='{tc.voice}'>"
+        f"<prosody pitch='{tc.pitch}' rate='{tc.rate}' volume='{tc.volume}'>"
+        f"{escaped_text}"
+        "</prosody>"
+        "</voice>"
+        "</speak>"
     )
-    return ssml


 def date_to_string() -> str:
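
Reviewer note (not part of the diff): mkssml() now takes everything except the text from the TTSConfig dataclass added in models.py below. With hypothetical settings, the concatenation above yields a single-line SSML document (wrapped here for readability), and the caller is still responsible for XML-escaping the text, hence the escaped_text name:

    tc = TTSConfig(voice="en-US-AriaNeural", rate="+0%", volume="+0%", pitch="+0Hz")
    print(mkssml(tc, "Hello, world!"))
    # <speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
    # <voice name='Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)'>
    # <prosody pitch='+0Hz' rate='+0%' volume='+0%'>Hello, world!</prosody></voice></speak>

(TTSConfig.__post_init__ expands the short voice name to the long form shown; see the new file below.)
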
@@ -207,7 +215,7 @@ def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str:
     )


-def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
+def calc_max_mesg_size(tts_config: TTSConfig) -> int:
     """Calculates the maximum message size for the given voice, rate, and volume.

     Returns:
@@ -219,7 +227,7 @@ def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
             ssml_headers_plus_data(
                 connect_id(),
                 date_to_string(),
-                mkssml("", voice, rate, volume, pitch),
+                mkssml(tts_config, ""),
             )
         )
         + 50  # margin of error
@@ -232,25 +240,6 @@ class Communicate:
     Class for communicating with the service.
     """

-    @staticmethod
-    def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
-        """
-        Validates the given string parameter based on type and pattern.
-
-        Args:
-            param_name (str): The name of the parameter.
-            param_value (str): The value of the parameter.
-            pattern (str): The pattern to validate the parameter against.
-
-        Returns:
-            str: The validated parameter.
-        """
-        if not isinstance(param_value, str):
-            raise TypeError(f"{param_name} must be str")
-        if re.match(pattern, param_value) is None:
-            raise ValueError(f"Invalid {param_name} '{param_value}'.")
-        return param_value
-
     def __init__(
         self,
         text: str,
@@ -269,46 +258,30 @@ class Communicate:
         Raises:
             ValueError: If the voice is not valid.
         """
+        # Validate TTS settings and store the TTSConfig object.
+        self.tts_config = TTSConfig(voice, rate, volume, pitch)
+
+        # Validate the text parameter.
         if not isinstance(text, str):
             raise TypeError("text must be str")
-        self.text: str = text

-        # Possible values for voice are:
-        # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
-        # - cy-GB-NiaNeural
-        # - fil-PH-AngeloNeural
-        # Always send the first variant as that is what Microsoft Edge does.
-        if not isinstance(voice, str):
-            raise TypeError("voice must be str")
-        self.voice: str = voice
-        match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", voice)
-        if match is not None:
-            lang = match.group(1)
-            region = match.group(2)
-            name = match.group(3)
-            if name.find("-") != -1:
-                region = region + "-" + name[: name.find("-")]
-                name = name[name.find("-") + 1 :]
-            self.voice = (
-                "Microsoft Server Speech Text to Speech Voice"
-                + f" ({lang}-{region}, {name})"
-            )
-
-        self.voice = self.validate_string_param(
-            "voice",
-            self.voice,
-            r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
+        # Split the text into multiple strings and store them.
+        self.texts = split_text_by_byte_length(
+            escape(remove_incompatible_characters(text)),
+            calc_max_mesg_size(self.tts_config),
         )
-        self.rate = self.validate_string_param("rate", rate, r"^[+-]\d+%$")
-        self.volume = self.validate_string_param("volume", volume, r"^[+-]\d+%$")
-        self.pitch = self.validate_string_param("pitch", pitch, r"^[+-]\d+Hz$")

+        # Validate the proxy parameter.
         if proxy is not None and not isinstance(proxy, str):
             raise TypeError("proxy must be str")
         self.proxy: Optional[str] = proxy

-        if not isinstance(connect_timeout, int) or not isinstance(receive_timeout, int):
-            raise TypeError("connect_timeout and receive_timeout must be int")
+        # Validate the timeout parameters.
+        if not isinstance(connect_timeout, int):
+            raise TypeError("connect_timeout must be int")
+        if not isinstance(receive_timeout, int):
+            raise TypeError("receive_timeout must be int")
         self.session_timeout = aiohttp.ClientTimeout(
             total=None,
             connect=None,
@@ -316,9 +289,34 @@ class Communicate:
             sock_read=receive_timeout,
         )

-    async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
-        """Streams audio and metadata from the service."""
+        # Store current state of TTS.
+        self.state: Dict[str, Any] = {
+            "partial_text": None,
+            "offset_compensation": 0,
+            "last_duration_offset": 0,
+            "stream_was_called": False,
+        }
+
+    def __parse_metadata(self, data: bytes) -> Dict[str, Any]:
+        for meta_obj in json.loads(data)["Metadata"]:
+            meta_type = meta_obj["Type"]
+            if meta_type == "WordBoundary":
+                current_offset = (
+                    meta_obj["Data"]["Offset"] + self.state["offset_compensation"]
+                )
+                current_duration = meta_obj["Data"]["Duration"]
+                return {
+                    "type": meta_type,
+                    "offset": current_offset,
+                    "duration": current_duration,
+                    "text": meta_obj["Data"]["text"]["Text"],
+                }
+            if meta_type in ("SessionEnd",):
+                continue
+            raise UnknownResponse(f"Unknown metadata type: {meta_type}")
+        raise UnexpectedResponse("No WordBoundary metadata found")

+    async def __stream(self) -> AsyncGenerator[Dict[str, Any], None]:
         async def send_command_request() -> None:
             """Sends the request to the service."""

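
Reviewer note (not part of the diff): the locals of the old stream() (texts, offset_compensation, last_duration_offset) now live in self.state so that __stream() can run once per text chunk and the next chunk can pick up where the previous one left off. For reference, a made-up audio.metadata payload in the shape __parse_metadata() expects:

    payload = {
        "Metadata": [
            {
                "Type": "WordBoundary",
                "Data": {"Offset": 1_000_000, "Duration": 3_500_000, "text": {"Text": "Hello"}},
            }
        ]
    }
    # __parse_metadata(json.dumps(payload).encode("utf-8")) returns
    # {"type": "WordBoundary", "offset": 1_000_000 + self.state["offset_compensation"],
    #  "duration": 3_500_000, "text": "Hello"}

If the offsets are the Speech service's usual 100-nanosecond ticks (an assumption, not stated in the diff), the 8_750_000 turn-end compensation added further down corresponds to roughly 0.875 s of trailing padding.
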
@@ -342,55 +340,25 @@
                 "}}}}\r\n"
             )

-        async def send_ssml_request() -> bool:
+        async def send_ssml_request() -> None:
             """Sends the SSML request to the service."""

-            # Get the next string from the generator.
-            text = next(texts, None)
-
-            # If there are no more strings, return False.
-            if text is None:
-                return False
-
-            # Send the request to the service and return True.
+            # Send the request to the service.
             await websocket.send_str(
                 ssml_headers_plus_data(
                     connect_id(),
                     date_to_string(),
-                    mkssml(text, self.voice, self.rate, self.volume, self.pitch),
+                    mkssml(
+                        self.tts_config,
+                        self.state["partial_text"],
+                    ),
                 )
             )
-            return True

-        def parse_metadata() -> Dict[str, Any]:
-            for meta_obj in json.loads(data)["Metadata"]:
-                meta_type = meta_obj["Type"]
-                if meta_type == "WordBoundary":
-                    current_offset = meta_obj["Data"]["Offset"] + offset_compensation
-                    current_duration = meta_obj["Data"]["Duration"]
-                    return {
-                        "type": meta_type,
-                        "offset": current_offset,
-                        "duration": current_duration,
-                        "text": meta_obj["Data"]["text"]["Text"],
-                    }
-                if meta_type in ("SessionEnd",):
-                    continue
-                raise UnknownResponse(f"Unknown metadata type: {meta_type}")
-            raise UnexpectedResponse("No WordBoundary metadata found")
-
-        # Split the text into multiple strings if it is too long for the service.
-        texts = split_text_by_byte_length(
-            escape(remove_incompatible_characters(self.text)),
-            calc_max_mesg_size(self.voice, self.rate, self.volume, self.pitch),
-        )
-
-        # Keep track of last duration + offset to calculate the offset
-        # upon word split.
-        last_duration_offset = 0
-
-        # Current offset compensations.
-        offset_compensation = 0
+        # audio_was_received indicates whether we have received audio data
+        # from the websocket. This is so we can raise an exception if we
+        # don't receive any audio data.
+        audio_was_received = False

         # Create a new connection to the service.
         ssl_ctx = ssl.create_default_context(cafile=certifi.where())
@@ -412,11 +380,6 @@
                 },
                 ssl=ssl_ctx,
             ) as websocket:
-                # audio_was_received indicates whether we have received audio data
-                # from the websocket. This is so we can raise an exception if we
-                # don't receive any audio data.
-                audio_was_received = False
-
                 # Send the request to the service.
                 await send_command_request()

@@ -425,53 +388,91 @@

             async for received in websocket:
                 if received.type == aiohttp.WSMsgType.TEXT:
-                    parameters, data = get_headers_and_data(received.data)
-                    path = parameters.get(b"Path")
+                    encoded_data: bytes = received.data.encode("utf-8")
+                    parameters, data = get_headers_and_data(
+                        encoded_data, encoded_data.find(b"\r\n\r\n")
+                    )
+
+                    path = parameters.get(b"Path", None)
                     if path == b"audio.metadata":
                         # Parse the metadata and yield it.
-                        parsed_metadata = parse_metadata()
+                        parsed_metadata = self.__parse_metadata(data)
                         yield parsed_metadata

                         # Update the last duration offset for use by the next SSML request.
-                        last_duration_offset = (
+                        self.state["last_duration_offset"] = (
                             parsed_metadata["offset"] + parsed_metadata["duration"]
                         )
                     elif path == b"turn.end":
                         # Update the offset compensation for the next SSML request.
-                        offset_compensation = last_duration_offset
+                        self.state["offset_compensation"] = self.state[
+                            "last_duration_offset"
+                        ]

                         # Use average padding typically added by the service
                         # to the end of the audio data. This seems to work pretty
                         # well for now, but we might ultimately need to use a
                         # more sophisticated method like using ffmpeg to get
                         # the actual duration of the audio data.
-                        offset_compensation += 8_750_000
+                        self.state["offset_compensation"] += 8_750_000

-                        # Send the next SSML request to the service.
-                        if not await send_ssml_request():
-                            break
+                        # Exit the loop so we can send the next SSML request.
+                        break
                     elif path not in (b"response", b"turn.start"):
-                        raise UnknownResponse(
-                            "The response from the service is not recognized.\n"
-                            + received.data
-                        )
+                        raise UnknownResponse("Unknown path received")
                 elif received.type == aiohttp.WSMsgType.BINARY:
+                    # Message is too short to contain header length.
                    if len(received.data) < 2:
                         raise UnexpectedResponse(
                             "We received a binary message, but it is missing the header length."
                         )

+                    # The first two bytes of the binary message contain the header length.
                     header_length = int.from_bytes(received.data[:2], "big")
-                    if len(received.data) < header_length + 2:
+                    if header_length > len(received.data):
                         raise UnexpectedResponse(
-                            "We received a binary message, but it is missing the audio data."
+                            "The header length is greater than the length of the data."
                         )

+                    # Parse the headers and data from the binary message.
+                    parameters, data = get_headers_and_data(
+                        received.data, header_length
+                    )
+
+                    # Check if the path is audio.
+                    if parameters.get(b"Path") != b"audio":
+                        raise UnexpectedResponse(
+                            "Received binary message, but the path is not audio."
+                        )
+
+                    # At termination of the stream, the service sends a binary message
+                    # with no Content-Type; this is expected. What is not expected is for
+                    # an MPEG audio stream to be sent with no data.
+                    content_type = parameters.get(b"Content-Type", None)
+                    if content_type not in [b"audio/mpeg", None]:
+                        raise UnexpectedResponse(
+                            "Received binary message, but with an unexpected Content-Type."
+                        )
+
+                    # We only allow no Content-Type if there is no data.
+                    if content_type is None:
+                        if len(data) == 0:
+                            continue
+
+                        # If the data is not empty, then we need to raise an exception.
+                        raise UnexpectedResponse(
+                            "Received binary message with no Content-Type, but with data."
+                        )
+
+                    # If the data is empty now, then we need to raise an exception.
+                    if len(data) == 0:
+                        raise UnexpectedResponse(
+                            "Received binary message, but it is missing the audio data."
+                        )
+
+                    # Yield the audio data.
                     audio_was_received = True
-                    yield {
-                        "type": "audio",
-                        "data": received.data[header_length + 2 :],
-                    }
+                    yield {"type": "audio", "data": data}
                 elif received.type == aiohttp.WSMsgType.ERROR:
                     raise WebSocketError(
                         received.data if received.data else "Unknown error"
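
Reviewer note (not part of the diff): a sketch of how the new binary-frame handling fits together, using a fabricated frame shaped the way the loop above expects (a 2-byte big-endian header length, CRLF-terminated headers, then MP3 data):

    header = b"X-RequestId:abc\r\nContent-Type:audio/mpeg\r\nPath:audio\r\n"
    frame = len(header).to_bytes(2, "big") + header + b"<mp3 bytes>"

    header_length = int.from_bytes(frame[:2], "big")
    parameters, data = get_headers_and_data(frame, header_length)
    assert parameters[b"Path"] == b"audio"
    assert parameters[b"Content-Type"] == b"audio/mpeg"
    assert data == b"<mp3 bytes>"

Because the whole frame (length prefix included) is passed to get_headers_and_data(), the two prefix bytes end up glued onto the first header key and the trailing CRLF absorbs the off-by-two in data[:header_length]; Path and Content-Type still parse cleanly as long as they are not the first header, which is what the checks above rely on.
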
@@ -482,6 +483,29 @@
                 "No audio was received. Please verify that your parameters are correct."
             )

+    async def stream(
+        self,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Streams audio and metadata from the service.
+
+        Raises:
+            NoAudioReceived: If no audio is received from the service.
+            UnexpectedResponse: If the response from the service is unexpected.
+            UnknownResponse: If the response from the service is unknown.
+            WebSocketError: If there is an error with the websocket.
+        """
+
+        # Check if stream was called before.
+        if self.state["stream_was_called"]:
+            raise RuntimeError("stream can only be called once.")
+        self.state["stream_was_called"] = True
+
+        # Stream the audio and metadata from the service.
+        for self.state["partial_text"] in self.texts:
+            async for message in self.__stream():
+                yield message
+
     async def save(
         self,
         audio_fname: Union[str, bytes],
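
Reviewer note (not part of the diff): for typical callers the public API is unchanged; stream() still yields "audio" and "WordBoundary" dictionaries. A minimal usage sketch with a hypothetical voice and output file:

    import asyncio
    import edge_tts

    async def main() -> None:
        communicate = edge_tts.Communicate("Hello, world!", "en-GB-SoniaNeural")
        with open("hello.mp3", "wb") as f:
            async for chunk in communicate.stream():
                if chunk["type"] == "audio":
                    f.write(chunk["data"])
                elif chunk["type"] == "WordBoundary":
                    print(chunk["offset"], chunk["duration"], chunk["text"])

    asyncio.run(main())

One behavioural change worth flagging: stream() now raises RuntimeError on a second call for the same Communicate object (guarded by self.state["stream_was_called"]), whereas the old implementation kept all of its state in locals and could be called again.
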
src/edge_tts/models.py (new file, +70 lines)
@@ -0,0 +1,70 @@
+"""Models for the Edge TTS module."""
+
+import re
+from dataclasses import dataclass
+
+
+@dataclass
+class TTSConfig:
+    """
+    Represents the internal TTS configuration for Edge TTS's communicate class.
+    """
+
+    voice: str
+    rate: str
+    volume: str
+    pitch: str
+
+    @staticmethod
+    def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
+        """
+        Validates the given string parameter based on type and pattern.
+
+        Args:
+            param_name (str): The name of the parameter.
+            param_value (str): The value of the parameter.
+            pattern (str): The pattern to validate the parameter against.
+
+        Returns:
+            str: The validated parameter.
+        """
+        if not isinstance(param_value, str):
+            raise TypeError(f"{param_name} must be str")
+        if re.match(pattern, param_value) is None:
+            raise ValueError(f"Invalid {param_name} '{param_value}'.")
+        return param_value
+
+    def __post_init__(self) -> None:
+        """
+        Validates the TTSConfig object after initialization.
+        """
+
+        # Possible values for voice are:
+        # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
+        # - cy-GB-NiaNeural
+        # - fil-PH-AngeloNeural
+        # Always send the first variant as that is what Microsoft Edge does.
+        if not isinstance(self.voice, str):
+            raise TypeError("voice must be str")
+        match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", self.voice)
+        if match is not None:
+            lang = match.group(1)
+            region = match.group(2)
+            name = match.group(3)
+            if name.find("-") != -1:
+                region = region + "-" + name[: name.find("-")]
+                name = name[name.find("-") + 1 :]
+            self.voice = (
+                "Microsoft Server Speech Text to Speech Voice"
+                + f" ({lang}-{region}, {name})"
+            )
+
+        # Validate the rate, volume, and pitch parameters.
+        self.validate_string_param(
+            "voice",
+            self.voice,
+            r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
+        )
+        self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
+        self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
+        self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
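
Reviewer note (not part of the diff): TTSConfig centralizes the validation that used to live on Communicate, and __post_init__ also performs the short-name expansion. A quick illustration with hypothetical values:

    from edge_tts.models import TTSConfig

    tc = TTSConfig(voice="cy-GB-NiaNeural", rate="+10%", volume="-5%", pitch="+0Hz")
    print(tc.voice)
    # Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)

    TTSConfig(voice="cy-GB-NiaNeural", rate="10%", volume="+0%", pitch="+0Hz")
    # ValueError: Invalid rate '10%'.  (rate must match ^[+-]\d+%$)

Minor nit: the comment above the final validation block says "rate, volume, and pitch", but the block validates the (already expanded) voice string as well.
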
@@ -8,7 +8,7 @@ import sys
 from io import TextIOWrapper
 from typing import Any, TextIO, Union

-from edge_tts import Communicate, SubMaker, list_voices
+from . import Communicate, SubMaker, list_voices


 async def _print_voices(*, proxy: str) -> None: