Slight performance improvements and bug fixes

* also add new functionality for edge-playback to keep temp files
* and bump version to 6.0.9
This commit is contained in:
rany2
2023-01-09 17:47:04 +02:00
parent d4da421ef6
commit bd9cc2bd2d
3 changed files with 62 additions and 55 deletions

View File

@@ -34,7 +34,7 @@ from edge_tts.exceptions import (
from .constants import WSS_URL
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes]:
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
"""
Returns the headers and data from the given data.
@@ -50,14 +50,11 @@ def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes
raise TypeError("data must be str or bytes")
headers = {}
for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"):
line_split = line.split(b":")
key, value = line_split[0], b":".join(line_split[1:])
if value.startswith(b" "):
value = value[1:]
headers[key.decode("utf-8")] = value.decode("utf-8")
for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
key, value = line.split(b":", 1)
headers[key] = value
return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:])
return headers, data[data.find(b"\r\n\r\n") + 4 :]
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
@@ -98,55 +95,59 @@ def connect_id() -> str:
return str(uuid.uuid4()).replace("-", "")
def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
    """
    Walks a bytes object one position at a time.

    Args:
        my_bytes: Bytes object to iterate over

    Yields:
        each byte in order, as a one-element slice of the input
    """
    # Slice rather than index: indexing bytes gives an int, slicing keeps bytes.
    yield from (my_bytes[start : start + 1] for start in range(len(my_bytes)))
def split_text_by_byte_length(
    text: Union[str, bytes], byte_length: int
) -> Generator[bytes, None, None]:
    """
    Splits a string into chunks of at most a given byte length
    while attempting to keep words together. This function assumes
    text will be inside of an XML tag.

    Args:
        text (str or bytes): The string to be split.
        byte_length (int): The maximum byte length of each chunk.

    Yields:
        bytes: The next non-empty chunk, stripped of surrounding whitespace.

    Raises:
        TypeError: If text is neither str nor bytes.
        ValueError: If byte_length is not positive, or if a chunk would
            have to begin inside an "&" sequence that has no ";" within
            the allowed length.
    """
    if isinstance(text, str):
        text = text.encode("utf-8")
    if not isinstance(text, bytes):
        raise TypeError("text must be str or bytes")

    if byte_length <= 0:
        raise ValueError("byte_length must be greater than 0")

    while len(text) > byte_length:
        # Prefer to split at the last space within the limit.
        split_at = text.rfind(b" ", 0, byte_length)

        if split_at == -1:
            # No space found: cut at the limit, but back up past UTF-8
            # continuation bytes (0b10xxxxxx) so that a multi-byte
            # character is never cut in half.
            split_at = byte_length
            while split_at > 0 and (text[split_at] & 0xC0) == 0x80:
                split_at -= 1
            if split_at == 0:
                # No character boundary inside the limit (input is not
                # valid UTF-8): fall back to the raw cut.
                split_at = byte_length

        # Verify all & are terminated with a ;
        while b"&" in text[:split_at]:
            ampersand_index = text.rindex(b"&", 0, split_at)
            if text.find(b";", ampersand_index, split_at) != -1:
                break

            # Cut immediately before the unterminated "&" so the whole
            # sequence moves to the next chunk and no byte in front of it
            # is lost.
            split_at = ampersand_index
            if split_at == 0:
                raise ValueError("Maximum byte length is too small or invalid text")

        new_text = text[:split_at].strip()
        if new_text:
            yield new_text

        # split_at is 0 only when the text starts with a space; step over
        # that space so the loop always makes progress.
        text = text[max(split_at, 1) :]

    new_text = text.strip()
    if new_text:
        yield new_text
def mkssml(text: Union[str, bytes], voice: str, rate: str, volume: str) -> str:
@@ -352,15 +353,14 @@ class Communicate:
async for received in websocket:
if received.type == aiohttp.WSMsgType.TEXT:
parameters, data = get_headers_and_data(received.data)
if parameters.get("Path") == "turn.start":
path = parameters.get(b"Path")
if path == b"turn.start":
download_audio = True
elif parameters.get("Path") == "turn.end":
elif path == b"turn.end":
download_audio = False
break # End of audio data
elif parameters.get("Path") == "audio.metadata":
meta = json.loads(data)
for i in range(len(meta["Metadata"])):
meta_obj = meta["Metadata"][i]
elif path == b"audio.metadata":
for meta_obj in json.loads(data)["Metadata"]:
meta_type = meta_obj["Type"]
if meta_type == "WordBoundary":
yield {
@@ -375,7 +375,7 @@ class Communicate:
raise UnknownResponse(
f"Unknown metadata type: {meta_type}"
)
elif parameters.get("Path") == "response":
elif path == b"response":
pass
else:
raise UnknownResponse(
@@ -390,13 +390,15 @@ class Communicate:
yield {
"type": "audio",
"data": b"Path:audio\r\n".join(
received.data.split(b"Path:audio\r\n")[1:]
),
"data": received.data[
received.data.find(b"Path:audio\r\n") + 12 :
],
}
audio_was_received = True
elif received.type == aiohttp.WSMsgType.ERROR:
raise WebSocketError(received.data)
raise WebSocketError(
received.data if received.data else "Unknown error"
)
if not audio_was_received:
raise NoAudioReceived(