* Add sync versions of stream and save methods In order to provide synchronous interface to the library * Fix save_sync() failing to use metadata_fname and fix typing issues Signed-off-by: rany <ranygh@riseup.net> --------- Signed-off-by: rany <ranygh@riseup.net> Co-authored-by: rany <ranygh@riseup.net>
541 lines
18 KiB
Python
541 lines
18 KiB
Python
"""
|
|
Communicate package.
|
|
"""
|
|
|
|
import asyncio
|
|
import concurrent.futures
|
|
import json
|
|
import re
|
|
import ssl
|
|
import time
|
|
import uuid
|
|
from contextlib import nullcontext
|
|
from io import TextIOWrapper
|
|
from queue import Queue
|
|
from typing import (
|
|
Any,
|
|
AsyncGenerator,
|
|
ContextManager,
|
|
Dict,
|
|
Generator,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
from xml.sax.saxutils import escape
|
|
|
|
import aiohttp
|
|
import certifi
|
|
|
|
from edge_tts.exceptions import (
|
|
NoAudioReceived,
|
|
UnexpectedResponse,
|
|
UnknownResponse,
|
|
WebSocketError,
|
|
)
|
|
|
|
from .constants import WSS_URL
|
|
|
|
|
|
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
|
|
"""
|
|
Returns the headers and data from the given data.
|
|
|
|
Args:
|
|
data (str or bytes): The data to be parsed.
|
|
|
|
Returns:
|
|
tuple: The headers and data to be used in the request.
|
|
"""
|
|
if isinstance(data, str):
|
|
data = data.encode("utf-8")
|
|
if not isinstance(data, bytes):
|
|
raise TypeError("data must be str or bytes")
|
|
|
|
headers = {}
|
|
for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
|
|
key, value = line.split(b":", 1)
|
|
headers[key] = value
|
|
|
|
return headers, data[data.find(b"\r\n\r\n") + 4 :]
|
|
|
|
|
|
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
|
|
"""
|
|
The service does not support a couple character ranges.
|
|
Most important being the vertical tab character which is
|
|
commonly present in OCR-ed PDFs. Not doing this will
|
|
result in an error from the service.
|
|
|
|
Args:
|
|
string (str or bytes): The string to be cleaned.
|
|
|
|
Returns:
|
|
str: The cleaned string.
|
|
"""
|
|
if isinstance(string, bytes):
|
|
string = string.decode("utf-8")
|
|
if not isinstance(string, str):
|
|
raise TypeError("string must be str or bytes")
|
|
|
|
chars: List[str] = list(string)
|
|
|
|
for idx, char in enumerate(chars):
|
|
code: int = ord(char)
|
|
if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31):
|
|
chars[idx] = " "
|
|
|
|
return "".join(chars)
|
|
|
|
|
|
def connect_id() -> str:
|
|
"""
|
|
Returns a UUID without dashes.
|
|
|
|
Returns:
|
|
str: A UUID without dashes.
|
|
"""
|
|
return str(uuid.uuid4()).replace("-", "")
|
|
|
|
|
|
def split_text_by_byte_length(
|
|
text: Union[str, bytes], byte_length: int
|
|
) -> Generator[bytes, None, None]:
|
|
"""
|
|
Splits a string into a list of strings of a given byte length
|
|
while attempting to keep words together. This function assumes
|
|
text will be inside of an XML tag.
|
|
|
|
Args:
|
|
text (str or bytes): The string to be split.
|
|
byte_length (int): The maximum byte length of each string in the list.
|
|
|
|
Yield:
|
|
bytes: The next string in the list.
|
|
"""
|
|
if isinstance(text, str):
|
|
text = text.encode("utf-8")
|
|
if not isinstance(text, bytes):
|
|
raise TypeError("text must be str or bytes")
|
|
|
|
if byte_length <= 0:
|
|
raise ValueError("byte_length must be greater than 0")
|
|
|
|
while len(text) > byte_length:
|
|
# Find the last space in the string
|
|
split_at = text.rfind(b" ", 0, byte_length)
|
|
|
|
# If no space found, split_at is byte_length
|
|
split_at = split_at if split_at != -1 else byte_length
|
|
|
|
# Verify all & are terminated with a ;
|
|
while b"&" in text[:split_at]:
|
|
ampersand_index = text.rindex(b"&", 0, split_at)
|
|
if text.find(b";", ampersand_index, split_at) != -1:
|
|
break
|
|
|
|
split_at = ampersand_index - 1
|
|
if split_at < 0:
|
|
raise ValueError("Maximum byte length is too small or invalid text")
|
|
if split_at == 0:
|
|
break
|
|
|
|
# Append the string to the list
|
|
new_text = text[:split_at].strip()
|
|
if new_text:
|
|
yield new_text
|
|
if split_at == 0:
|
|
split_at = 1
|
|
text = text[split_at:]
|
|
|
|
new_text = text.strip()
|
|
if new_text:
|
|
yield new_text
|
|
|
|
|
|
def mkssml(
|
|
text: Union[str, bytes], voice: str, rate: str, volume: str, pitch: str
|
|
) -> str:
|
|
"""
|
|
Creates a SSML string from the given parameters.
|
|
|
|
Returns:
|
|
str: The SSML string.
|
|
"""
|
|
if isinstance(text, bytes):
|
|
text = text.decode("utf-8")
|
|
|
|
ssml = (
|
|
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
|
|
f"<voice name='{voice}'><prosody pitch='{pitch}' rate='{rate}' volume='{volume}'>"
|
|
f"{text}</prosody></voice></speak>"
|
|
)
|
|
return ssml
|
|
|
|
|
|
def date_to_string() -> str:
|
|
"""
|
|
Return Javascript-style date string.
|
|
|
|
Returns:
|
|
str: Javascript-style date string.
|
|
"""
|
|
# %Z is not what we want, but it's the only way to get the timezone
|
|
# without having to use a library. We'll just use UTC and hope for the best.
|
|
# For example, right now %Z would return EEST when we need it to return
|
|
# Eastern European Summer Time.
|
|
return time.strftime(
|
|
"%a %b %d %Y %H:%M:%S GMT+0000 (Coordinated Universal Time)", time.gmtime()
|
|
)
|
|
|
|
|
|
def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str:
|
|
"""
|
|
Returns the headers and data to be used in the request.
|
|
|
|
Returns:
|
|
str: The headers and data to be used in the request.
|
|
"""
|
|
|
|
return (
|
|
f"X-RequestId:{request_id}\r\n"
|
|
"Content-Type:application/ssml+xml\r\n"
|
|
f"X-Timestamp:{timestamp}Z\r\n" # This is not a mistake, Microsoft Edge bug.
|
|
"Path:ssml\r\n\r\n"
|
|
f"{ssml}"
|
|
)
|
|
|
|
|
|
def calc_max_mesg_size(voice: str, rate: str, volume: str, pitch: str) -> int:
|
|
"""Calculates the maximum message size for the given voice, rate, and volume.
|
|
|
|
Returns:
|
|
int: The maximum message size.
|
|
"""
|
|
websocket_max_size: int = 2**16
|
|
overhead_per_message: int = (
|
|
len(
|
|
ssml_headers_plus_data(
|
|
connect_id(),
|
|
date_to_string(),
|
|
mkssml("", voice, rate, volume, pitch),
|
|
)
|
|
)
|
|
+ 50 # margin of error
|
|
)
|
|
return websocket_max_size - overhead_per_message
|
|
|
|
|
|
class Communicate:
|
|
"""
|
|
Class for communicating with the service.
|
|
"""
|
|
|
|
@staticmethod
|
|
def validate_string_param(param_name: str, param_value: str, pattern: str) -> str:
|
|
"""
|
|
Validates the given string parameter based on type and pattern.
|
|
|
|
Args:
|
|
param_name (str): The name of the parameter.
|
|
param_value (str): The value of the parameter.
|
|
pattern (str): The pattern to validate the parameter against.
|
|
|
|
Returns:
|
|
str: The validated parameter.
|
|
"""
|
|
if not isinstance(param_value, str):
|
|
raise TypeError(f"{param_name} must be str")
|
|
if re.match(pattern, param_value) is None:
|
|
raise ValueError(f"Invalid {param_name} '{param_value}'.")
|
|
return param_value
|
|
|
|
def __init__(
|
|
self,
|
|
text: str,
|
|
voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
|
|
*,
|
|
rate: str = "+0%",
|
|
volume: str = "+0%",
|
|
pitch: str = "+0Hz",
|
|
proxy: Optional[str] = None,
|
|
receive_timeout: int = 5,
|
|
):
|
|
"""
|
|
Initializes the Communicate class.
|
|
|
|
Raises:
|
|
ValueError: If the voice is not valid.
|
|
"""
|
|
if not isinstance(text, str):
|
|
raise TypeError("text must be str")
|
|
self.text: str = text
|
|
|
|
# Possible values for voice are:
|
|
# - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
|
|
# - cy-GB-NiaNeural
|
|
# - fil-PH-AngeloNeural
|
|
# Always send the first variant as that is what Microsoft Edge does.
|
|
if not isinstance(voice, str):
|
|
raise TypeError("voice must be str")
|
|
self.voice: str = voice
|
|
match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", voice)
|
|
if match is not None:
|
|
lang = match.group(1)
|
|
region = match.group(2)
|
|
name = match.group(3)
|
|
if name.find("-") != -1:
|
|
region = region + "-" + name[: name.find("-")]
|
|
name = name[name.find("-") + 1 :]
|
|
self.voice = (
|
|
"Microsoft Server Speech Text to Speech Voice"
|
|
+ f" ({lang}-{region}, {name})"
|
|
)
|
|
|
|
self.voice = self.validate_string_param(
|
|
"voice",
|
|
self.voice,
|
|
r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
|
|
)
|
|
self.rate = self.validate_string_param("rate", rate, r"^[+-]\d+%$")
|
|
self.volume = self.validate_string_param("volume", volume, r"^[+-]\d+%$")
|
|
self.pitch = self.validate_string_param("pitch", pitch, r"^[+-]\d+Hz$")
|
|
|
|
if proxy is not None and not isinstance(proxy, str):
|
|
raise TypeError("proxy must be str")
|
|
self.proxy: Optional[str] = proxy
|
|
|
|
if not isinstance(receive_timeout, int):
|
|
raise TypeError("receive_timeout must be int")
|
|
self.receive_timeout: int = receive_timeout
|
|
|
|
async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
|
|
"""Streams audio and metadata from the service."""
|
|
|
|
async def send_command_request() -> None:
|
|
"""Sends the request to the service."""
|
|
|
|
# Prepare the request to be sent to the service.
|
|
#
|
|
# Note sentenceBoundaryEnabled and wordBoundaryEnabled are actually supposed
|
|
# to be booleans, but Edge Browser seems to send them as strings.
|
|
#
|
|
# This is a bug in Edge as Azure Cognitive Services actually sends them as
|
|
# bool and not string. For now I will send them as bool unless it causes
|
|
# any problems.
|
|
#
|
|
# Also pay close attention to double { } in request (escape for f-string).
|
|
await websocket.send_str(
|
|
f"X-Timestamp:{date_to_string()}\r\n"
|
|
"Content-Type:application/json; charset=utf-8\r\n"
|
|
"Path:speech.config\r\n\r\n"
|
|
'{"context":{"synthesis":{"audio":{"metadataoptions":{'
|
|
'"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},'
|
|
'"outputFormat":"audio-24khz-48kbitrate-mono-mp3"'
|
|
"}}}}\r\n"
|
|
)
|
|
|
|
async def send_ssml_request() -> bool:
|
|
"""Sends the SSML request to the service."""
|
|
|
|
# Get the next string from the generator.
|
|
text = next(texts, None)
|
|
|
|
# If there are no more strings, return False.
|
|
if text is None:
|
|
return False
|
|
|
|
# Send the request to the service and return True.
|
|
await websocket.send_str(
|
|
ssml_headers_plus_data(
|
|
connect_id(),
|
|
date_to_string(),
|
|
mkssml(text, self.voice, self.rate, self.volume, self.pitch),
|
|
)
|
|
)
|
|
return True
|
|
|
|
def parse_metadata() -> Dict[str, Any]:
|
|
for meta_obj in json.loads(data)["Metadata"]:
|
|
meta_type = meta_obj["Type"]
|
|
if meta_type == "WordBoundary":
|
|
current_offset = meta_obj["Data"]["Offset"] + offset_compensation
|
|
current_duration = meta_obj["Data"]["Duration"]
|
|
return {
|
|
"type": meta_type,
|
|
"offset": current_offset,
|
|
"duration": current_duration,
|
|
"text": meta_obj["Data"]["text"]["Text"],
|
|
}
|
|
if meta_type in ("SessionEnd",):
|
|
continue
|
|
raise UnknownResponse(f"Unknown metadata type: {meta_type}")
|
|
raise UnexpectedResponse("No WordBoundary metadata found")
|
|
|
|
# Split the text into multiple strings if it is too long for the service.
|
|
texts = split_text_by_byte_length(
|
|
escape(remove_incompatible_characters(self.text)),
|
|
calc_max_mesg_size(self.voice, self.rate, self.volume, self.pitch),
|
|
)
|
|
|
|
# Keep track of last duration + offset to calculate the offset
|
|
# upon word split.
|
|
last_duration_offset = 0
|
|
|
|
# Current offset compensations.
|
|
offset_compensation = 0
|
|
|
|
# Create a new connection to the service.
|
|
ssl_ctx = ssl.create_default_context(cafile=certifi.where())
|
|
async with aiohttp.ClientSession(
|
|
trust_env=True,
|
|
) as session, session.ws_connect(
|
|
f"{WSS_URL}&ConnectionId={connect_id()}",
|
|
compress=15,
|
|
proxy=self.proxy,
|
|
receive_timeout=self.receive_timeout,
|
|
headers={
|
|
"Pragma": "no-cache",
|
|
"Cache-Control": "no-cache",
|
|
"Origin": "chrome-extension://jdiccldimpdaibmpdkjnbmckianbfold",
|
|
"Accept-Encoding": "gzip, deflate, br",
|
|
"Accept-Language": "en-US,en;q=0.9",
|
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
|
" (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41",
|
|
},
|
|
ssl=ssl_ctx,
|
|
) as websocket:
|
|
# audio_was_received indicates whether we have received audio data
|
|
# from the websocket. This is so we can raise an exception if we
|
|
# don't receive any audio data.
|
|
audio_was_received = False
|
|
|
|
# Send the request to the service.
|
|
await send_command_request()
|
|
|
|
# Send the SSML request to the service.
|
|
await send_ssml_request()
|
|
|
|
async for received in websocket:
|
|
if received.type == aiohttp.WSMsgType.TEXT:
|
|
parameters, data = get_headers_and_data(received.data)
|
|
path = parameters.get(b"Path")
|
|
if path == b"audio.metadata":
|
|
# Parse the metadata and yield it.
|
|
parsed_metadata = parse_metadata()
|
|
yield parsed_metadata
|
|
|
|
# Update the last duration offset for use by the next SSML request.
|
|
last_duration_offset = (
|
|
parsed_metadata["offset"] + parsed_metadata["duration"]
|
|
)
|
|
elif path == b"turn.end":
|
|
# Update the offset compensation for the next SSML request.
|
|
offset_compensation = last_duration_offset
|
|
|
|
# Use average padding typically added by the service
|
|
# to the end of the audio data. This seems to work pretty
|
|
# well for now, but we might ultimately need to use a
|
|
# more sophisticated method like using ffmpeg to get
|
|
# the actual duration of the audio data.
|
|
offset_compensation += 8_750_000
|
|
|
|
# Send the next SSML request to the service.
|
|
if not await send_ssml_request():
|
|
break
|
|
elif path in (b"response", b"turn.start"):
|
|
pass
|
|
else:
|
|
raise UnknownResponse(
|
|
"The response from the service is not recognized.\n"
|
|
+ received.data
|
|
)
|
|
elif received.type == aiohttp.WSMsgType.BINARY:
|
|
if len(received.data) < 2:
|
|
raise UnexpectedResponse(
|
|
"We received a binary message, but it is missing the header length."
|
|
)
|
|
|
|
header_length = int.from_bytes(received.data[:2], "big")
|
|
if len(received.data) < header_length + 2:
|
|
raise UnexpectedResponse(
|
|
"We received a binary message, but it is missing the audio data."
|
|
)
|
|
|
|
audio_was_received = True
|
|
yield {
|
|
"type": "audio",
|
|
"data": received.data[header_length + 2 :],
|
|
}
|
|
elif received.type == aiohttp.WSMsgType.ERROR:
|
|
raise WebSocketError(
|
|
received.data if received.data else "Unknown error"
|
|
)
|
|
|
|
if not audio_was_received:
|
|
raise NoAudioReceived(
|
|
"No audio was received. Please verify that your parameters are correct."
|
|
)
|
|
|
|
async def save(
|
|
self,
|
|
audio_fname: Union[str, bytes],
|
|
metadata_fname: Optional[Union[str, bytes]] = None,
|
|
) -> None:
|
|
"""
|
|
Save the audio and metadata to the specified files.
|
|
"""
|
|
metadata: Union[TextIOWrapper, ContextManager[None]] = (
|
|
open(metadata_fname, "w", encoding="utf-8")
|
|
if metadata_fname is not None
|
|
else nullcontext()
|
|
)
|
|
with metadata, open(audio_fname, "wb") as audio:
|
|
async for message in self.stream():
|
|
if message["type"] == "audio":
|
|
audio.write(message["data"])
|
|
elif (
|
|
isinstance(metadata, TextIOWrapper)
|
|
and message["type"] == "WordBoundary"
|
|
):
|
|
json.dump(message, metadata)
|
|
metadata.write("\n")
|
|
|
|
def stream_sync(self) -> Generator[Dict[str, Any], None, None]:
|
|
"""Synchronous interface for async stream method"""
|
|
|
|
def fetch_async_items(queue: Queue) -> None: # type: ignore
|
|
async def get_items() -> None:
|
|
async for item in self.stream():
|
|
queue.put(item)
|
|
queue.put(None)
|
|
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
loop.run_until_complete(get_items())
|
|
loop.close()
|
|
|
|
queue: Queue = Queue() # type: ignore
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
executor.submit(fetch_async_items, queue)
|
|
|
|
while True:
|
|
item = queue.get()
|
|
if item is None:
|
|
break
|
|
yield item
|
|
|
|
def save_sync(
|
|
self,
|
|
audio_fname: Union[str, bytes],
|
|
metadata_fname: Optional[Union[str, bytes]] = None,
|
|
) -> None:
|
|
"""Synchronous interface for async save method."""
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
future = executor.submit(
|
|
asyncio.run, self.save(audio_fname, metadata_fname)
|
|
)
|
|
future.result()
|