From bd9cc2bd2d922458e9448b9730d5d2e536393e28 Mon Sep 17 00:00:00 2001 From: rany2 Date: Mon, 9 Jan 2023 17:47:04 +0200 Subject: [PATCH] Slight performance improvements and bug fixes * also add new functionality for edge-playback to keep temp files * and bump version to 6.0.9 --- setup.cfg | 2 +- src/edge_playback/__main__.py | 9 ++- src/edge_tts/communicate.py | 106 +++++++++++++++++----------------- 3 files changed, 62 insertions(+), 55 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7cf9030..1e85359 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = edge-tts -version = 6.0.8 +version = 6.0.9 author = rany author_email = ranygh@riseup.net description = Microsoft Edge's TTS diff --git a/src/edge_playback/__main__.py b/src/edge_playback/__main__.py index 0c4e46d..da84fb0 100644 --- a/src/edge_playback/__main__.py +++ b/src/edge_playback/__main__.py @@ -4,6 +4,7 @@ Playback TTS with subtitles using edge-tts and mpv. """ +import os import subprocess import sys import tempfile @@ -22,9 +23,10 @@ def _main() -> None: print("Please install the missing dependencies.", file=sys.stderr) sys.exit(1) + keep = os.environ.get("EDGE_PLAYBACK_KEEP_TEMP") is not None with tempfile.NamedTemporaryFile( - suffix=".mp3", delete=False - ) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=False) as subtitle: + suffix=".mp3", delete=not keep + ) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=not keep) as subtitle: media.close() subtitle.close() @@ -49,6 +51,9 @@ def _main() -> None: ) as process: process.communicate() + if keep: + print(f"\nKeeping temporary files: {media.name} and {subtitle.name}") + if __name__ == "__main__": _main() diff --git a/src/edge_tts/communicate.py b/src/edge_tts/communicate.py index 9643271..f518c8e 100644 --- a/src/edge_tts/communicate.py +++ b/src/edge_tts/communicate.py @@ -34,7 +34,7 @@ from edge_tts.exceptions import ( from .constants import WSS_URL -def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes]: +def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]: """ Returns the headers and data from the given data. @@ -50,14 +50,11 @@ def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes raise TypeError("data must be str or bytes") headers = {} - for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"): - line_split = line.split(b":") - key, value = line_split[0], b":".join(line_split[1:]) - if value.startswith(b" "): - value = value[1:] - headers[key.decode("utf-8")] = value.decode("utf-8") + for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"): + key, value = line.split(b":", 1) + headers[key] = value - return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:]) + return headers, data[data.find(b"\r\n\r\n") + 4 :] def remove_incompatible_characters(string: Union[str, bytes]) -> str: @@ -98,55 +95,59 @@ def connect_id() -> str: return str(uuid.uuid4()).replace("-", "") -def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]: - """ - Iterates over bytes object - - Args: - my_bytes: Bytes object to iterate over - - Yields: - the individual bytes - """ - for i in range(len(my_bytes)): - yield my_bytes[i : i + 1] - - -def split_text_by_byte_length(text: Union[str, bytes], byte_length: int) -> List[bytes]: +def split_text_by_byte_length( + text: Union[str, bytes], byte_length: int +) -> Generator[bytes, None, None]: """ Splits a string into a list of strings of a given byte length - while attempting to keep words together. + while attempting to keep words together. This function assumes + text will be inside of an XML tag. Args: text (str or bytes): The string to be split. byte_length (int): The maximum byte length of each string in the list. - Returns: - list: A list of bytes of the given byte length. + Yield: + bytes: The next string in the list. """ if isinstance(text, str): text = text.encode("utf-8") if not isinstance(text, bytes): raise TypeError("text must be str or bytes") - words = [] + if byte_length <= 0: + raise ValueError("byte_length must be greater than 0") + while len(text) > byte_length: # Find the last space in the string - last_space = text.rfind(b" ", 0, byte_length) - if last_space == -1: - # No space found, just split at the byte length - words.append(text[:byte_length]) - text = text[byte_length:] - else: - # Split at the last space - words.append(text[:last_space]) - text = text[last_space:] - words.append(text) + split_at = text.rfind(b" ", 0, byte_length) - # Remove empty strings from the list - words = [word for word in words if word] - # Return the list - return words + # If no space found, split_at is byte_length + split_at = split_at if split_at != -1 else byte_length + + # Verify all & are terminated with a ; + while b"&" in text[:split_at]: + ampersand_index = text.rindex(b"&", 0, split_at) + if text.find(b";", ampersand_index, split_at) != -1: + break + + split_at = ampersand_index - 1 + if split_at < 0: + raise ValueError("Maximum byte length is too small or invalid text") + if split_at == 0: + break + + # Append the string to the list + new_text = text[:split_at].strip() + if new_text: + yield new_text + if split_at == 0: + split_at = 1 + text = text[split_at:] + + new_text = text.strip() + if new_text: + yield new_text def mkssml(text: Union[str, bytes], voice: str, rate: str, volume: str) -> str: @@ -352,15 +353,14 @@ class Communicate: async for received in websocket: if received.type == aiohttp.WSMsgType.TEXT: parameters, data = get_headers_and_data(received.data) - if parameters.get("Path") == "turn.start": + path = parameters.get(b"Path") + if path == b"turn.start": download_audio = True - elif parameters.get("Path") == "turn.end": + elif path == b"turn.end": download_audio = False break # End of audio data - elif parameters.get("Path") == "audio.metadata": - meta = json.loads(data) - for i in range(len(meta["Metadata"])): - meta_obj = meta["Metadata"][i] + elif path == b"audio.metadata": + for meta_obj in json.loads(data)["Metadata"]: meta_type = meta_obj["Type"] if meta_type == "WordBoundary": yield { @@ -375,7 +375,7 @@ class Communicate: raise UnknownResponse( f"Unknown metadata type: {meta_type}" ) - elif parameters.get("Path") == "response": + elif path == b"response": pass else: raise UnknownResponse( @@ -390,13 +390,15 @@ class Communicate: yield { "type": "audio", - "data": b"Path:audio\r\n".join( - received.data.split(b"Path:audio\r\n")[1:] - ), + "data": received.data[ + received.data.find(b"Path:audio\r\n") + 12 : + ], } audio_was_received = True elif received.type == aiohttp.WSMsgType.ERROR: - raise WebSocketError(received.data) + raise WebSocketError( + received.data if received.data else "Unknown error" + ) if not audio_was_received: raise NoAudioReceived(