Slight performance improvements and bug fixes
* Also add new functionality for edge-playback to keep temp files
* Bump version to 6.0.9
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
name = edge-tts
|
name = edge-tts
|
||||||
version = 6.0.8
|
version = 6.0.9
|
||||||
author = rany
|
author = rany
|
||||||
author_email = ranygh@riseup.net
|
author_email = ranygh@riseup.net
|
||||||
description = Microsoft Edge's TTS
|
description = Microsoft Edge's TTS
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
Playback TTS with subtitles using edge-tts and mpv.
|
Playback TTS with subtitles using edge-tts and mpv.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -22,9 +23,10 @@ def _main() -> None:
|
|||||||
print("Please install the missing dependencies.", file=sys.stderr)
|
print("Please install the missing dependencies.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
keep = os.environ.get("EDGE_PLAYBACK_KEEP_TEMP") is not None
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(
|
||||||
suffix=".mp3", delete=False
|
suffix=".mp3", delete=not keep
|
||||||
) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=False) as subtitle:
|
) as media, tempfile.NamedTemporaryFile(suffix=".vtt", delete=not keep) as subtitle:
|
||||||
media.close()
|
media.close()
|
||||||
subtitle.close()
|
subtitle.close()
|
||||||
|
|
||||||
@@ -49,6 +51,9 @@ def _main() -> None:
|
|||||||
) as process:
|
) as process:
|
||||||
process.communicate()
|
process.communicate()
|
||||||
|
|
||||||
|
if keep:
|
||||||
|
print(f"\nKeeping temporary files: {media.name} and {subtitle.name}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_main()
|
_main()
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from edge_tts.exceptions import (
|
|||||||
from .constants import WSS_URL
|
from .constants import WSS_URL
|
||||||
|
|
||||||
|
|
||||||
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes]:
|
def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[bytes, bytes], bytes]:
|
||||||
"""
|
"""
|
||||||
Returns the headers and data from the given data.
|
Returns the headers and data from the given data.
|
||||||
|
|
||||||
@@ -50,14 +50,11 @@ def get_headers_and_data(data: Union[str, bytes]) -> Tuple[Dict[str, str], bytes
|
|||||||
raise TypeError("data must be str or bytes")
|
raise TypeError("data must be str or bytes")
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"):
|
for line in data[: data.find(b"\r\n\r\n")].split(b"\r\n"):
|
||||||
line_split = line.split(b":")
|
key, value = line.split(b":", 1)
|
||||||
key, value = line_split[0], b":".join(line_split[1:])
|
headers[key] = value
|
||||||
if value.startswith(b" "):
|
|
||||||
value = value[1:]
|
|
||||||
headers[key.decode("utf-8")] = value.decode("utf-8")
|
|
||||||
|
|
||||||
return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:])
|
return headers, data[data.find(b"\r\n\r\n") + 4 :]
|
||||||
|
|
||||||
|
|
||||||
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
|
def remove_incompatible_characters(string: Union[str, bytes]) -> str:
|
||||||
@@ -98,55 +95,59 @@ def connect_id() -> str:
|
|||||||
return str(uuid.uuid4()).replace("-", "")
|
return str(uuid.uuid4()).replace("-", "")
|
||||||
|
|
||||||
|
|
||||||
def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
|
def split_text_by_byte_length(
|
||||||
"""
|
text: Union[str, bytes], byte_length: int
|
||||||
Iterates over bytes object
|
) -> Generator[bytes, None, None]:
|
||||||
|
|
||||||
Args:
|
|
||||||
my_bytes: Bytes object to iterate over
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
the individual bytes
|
|
||||||
"""
|
|
||||||
for i in range(len(my_bytes)):
|
|
||||||
yield my_bytes[i : i + 1]
|
|
||||||
|
|
||||||
|
|
||||||
def split_text_by_byte_length(text: Union[str, bytes], byte_length: int) -> List[bytes]:
|
|
||||||
"""
|
"""
|
||||||
Splits a string into a list of strings of a given byte length
|
Splits a string into a list of strings of a given byte length
|
||||||
while attempting to keep words together.
|
while attempting to keep words together. This function assumes
|
||||||
|
text will be inside of an XML tag.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str or bytes): The string to be split.
|
text (str or bytes): The string to be split.
|
||||||
byte_length (int): The maximum byte length of each string in the list.
|
byte_length (int): The maximum byte length of each string in the list.
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
list: A list of bytes of the given byte length.
|
bytes: The next string in the list.
|
||||||
"""
|
"""
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = text.encode("utf-8")
|
text = text.encode("utf-8")
|
||||||
if not isinstance(text, bytes):
|
if not isinstance(text, bytes):
|
||||||
raise TypeError("text must be str or bytes")
|
raise TypeError("text must be str or bytes")
|
||||||
|
|
||||||
words = []
|
if byte_length <= 0:
|
||||||
|
raise ValueError("byte_length must be greater than 0")
|
||||||
|
|
||||||
while len(text) > byte_length:
|
while len(text) > byte_length:
|
||||||
# Find the last space in the string
|
# Find the last space in the string
|
||||||
last_space = text.rfind(b" ", 0, byte_length)
|
split_at = text.rfind(b" ", 0, byte_length)
|
||||||
if last_space == -1:
|
|
||||||
# No space found, just split at the byte length
|
|
||||||
words.append(text[:byte_length])
|
|
||||||
text = text[byte_length:]
|
|
||||||
else:
|
|
||||||
# Split at the last space
|
|
||||||
words.append(text[:last_space])
|
|
||||||
text = text[last_space:]
|
|
||||||
words.append(text)
|
|
||||||
|
|
||||||
# Remove empty strings from the list
|
# If no space found, split_at is byte_length
|
||||||
words = [word for word in words if word]
|
split_at = split_at if split_at != -1 else byte_length
|
||||||
# Return the list
|
|
||||||
return words
|
# Verify all & are terminated with a ;
|
||||||
|
while b"&" in text[:split_at]:
|
||||||
|
ampersand_index = text.rindex(b"&", 0, split_at)
|
||||||
|
if text.find(b";", ampersand_index, split_at) != -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
split_at = ampersand_index - 1
|
||||||
|
if split_at < 0:
|
||||||
|
raise ValueError("Maximum byte length is too small or invalid text")
|
||||||
|
if split_at == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Append the string to the list
|
||||||
|
new_text = text[:split_at].strip()
|
||||||
|
if new_text:
|
||||||
|
yield new_text
|
||||||
|
if split_at == 0:
|
||||||
|
split_at = 1
|
||||||
|
text = text[split_at:]
|
||||||
|
|
||||||
|
new_text = text.strip()
|
||||||
|
if new_text:
|
||||||
|
yield new_text
|
||||||
|
|
||||||
|
|
||||||
def mkssml(text: Union[str, bytes], voice: str, rate: str, volume: str) -> str:
|
def mkssml(text: Union[str, bytes], voice: str, rate: str, volume: str) -> str:
|
||||||
@@ -352,15 +353,14 @@ class Communicate:
|
|||||||
async for received in websocket:
|
async for received in websocket:
|
||||||
if received.type == aiohttp.WSMsgType.TEXT:
|
if received.type == aiohttp.WSMsgType.TEXT:
|
||||||
parameters, data = get_headers_and_data(received.data)
|
parameters, data = get_headers_and_data(received.data)
|
||||||
if parameters.get("Path") == "turn.start":
|
path = parameters.get(b"Path")
|
||||||
|
if path == b"turn.start":
|
||||||
download_audio = True
|
download_audio = True
|
||||||
elif parameters.get("Path") == "turn.end":
|
elif path == b"turn.end":
|
||||||
download_audio = False
|
download_audio = False
|
||||||
break # End of audio data
|
break # End of audio data
|
||||||
elif parameters.get("Path") == "audio.metadata":
|
elif path == b"audio.metadata":
|
||||||
meta = json.loads(data)
|
for meta_obj in json.loads(data)["Metadata"]:
|
||||||
for i in range(len(meta["Metadata"])):
|
|
||||||
meta_obj = meta["Metadata"][i]
|
|
||||||
meta_type = meta_obj["Type"]
|
meta_type = meta_obj["Type"]
|
||||||
if meta_type == "WordBoundary":
|
if meta_type == "WordBoundary":
|
||||||
yield {
|
yield {
|
||||||
@@ -375,7 +375,7 @@ class Communicate:
|
|||||||
raise UnknownResponse(
|
raise UnknownResponse(
|
||||||
f"Unknown metadata type: {meta_type}"
|
f"Unknown metadata type: {meta_type}"
|
||||||
)
|
)
|
||||||
elif parameters.get("Path") == "response":
|
elif path == b"response":
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise UnknownResponse(
|
raise UnknownResponse(
|
||||||
@@ -390,13 +390,15 @@ class Communicate:
|
|||||||
|
|
||||||
yield {
|
yield {
|
||||||
"type": "audio",
|
"type": "audio",
|
||||||
"data": b"Path:audio\r\n".join(
|
"data": received.data[
|
||||||
received.data.split(b"Path:audio\r\n")[1:]
|
received.data.find(b"Path:audio\r\n") + 12 :
|
||||||
),
|
],
|
||||||
}
|
}
|
||||||
audio_was_received = True
|
audio_was_received = True
|
||||||
elif received.type == aiohttp.WSMsgType.ERROR:
|
elif received.type == aiohttp.WSMsgType.ERROR:
|
||||||
raise WebSocketError(received.data)
|
raise WebSocketError(
|
||||||
|
received.data if received.data else "Unknown error"
|
||||||
|
)
|
||||||
|
|
||||||
if not audio_was_received:
|
if not audio_was_received:
|
||||||
raise NoAudioReceived(
|
raise NoAudioReceived(
|
||||||
|
|||||||
Reference in New Issue
Block a user