Simplify edge_tts library usage

rany2 committed 2023-01-04 23:45:22 +02:00
parent 142b4f6457, commit 23370b4c27
5 changed files with 263 additions and 221 deletions

View File

@@ -12,44 +12,51 @@ from shutil import which
 def main():
-    """
-    Main function.
-    """
-    if which("mpv") and which("edge-tts"):
-        media = tempfile.NamedTemporaryFile(delete=False)
-        subtitle = tempfile.NamedTemporaryFile(delete=False)
-        try:
-            media.close()
-            subtitle.close()
-
-            print()
-            print(f"Media file: {media.name}")
-            print(f"Subtitle file: {subtitle.name}\n")
-            with subprocess.Popen(
-                [
-                    "edge-tts",
-                    "--boundary-type=1",
-                    f"--write-media={media.name}",
-                    f"--write-subtitles={subtitle.name}",
-                ]
-                + sys.argv[1:]
-            ) as process:
-                process.communicate()
-
-            with subprocess.Popen(
-                [
-                    "mpv",
-                    "--keep-open=yes",
-                    f"--sub-file={subtitle.name}",
-                    media.name,
-                ]
-            ) as process:
-                process.communicate()
-        finally:
-            os.unlink(media.name)
-            os.unlink(subtitle.name)
-    else:
-        print("This script requires mpv and edge-tts.")
+    depcheck_failed = False
+    if not which("mpv"):
+        print("mpv is not installed.", file=sys.stderr)
+        depcheck_failed = True
+    if not which("edge-tts"):
+        print("edge-tts is not installed.", file=sys.stderr)
+        depcheck_failed = True
+    if depcheck_failed:
+        print("Please install the missing dependencies.", file=sys.stderr)
+        sys.exit(1)
+
+    media = None
+    subtitle = None
+    try:
+        media = tempfile.NamedTemporaryFile(delete=False)
+        media.close()
+
+        subtitle = tempfile.NamedTemporaryFile(delete=False)
+        subtitle.close()
+
+        print(f"Media file: {media.name}")
+        print(f"Subtitle file: {subtitle.name}\n")
+
+        with subprocess.Popen(
+            [
+                "edge-tts",
+                f"--write-media={media.name}",
+                f"--write-subtitles={subtitle.name}",
+            ]
+            + sys.argv[1:]
+        ) as process:
+            process.communicate()
+
+        with subprocess.Popen(
+            [
+                "mpv",
+                f"--sub-file={subtitle.name}",
+                media.name,
+            ]
+        ) as process:
+            process.communicate()
+    finally:
+        if media is not None:
+            os.unlink(media.name)
+        if subtitle is not None:
+            os.unlink(subtitle.name)


 if __name__ == "__main__":

View File

@@ -4,16 +4,20 @@ Communicate package.
 import json
+import re
 import time
 import uuid
+from typing import Dict, Generator, List, Optional
 from xml.sax.saxutils import escape

 import aiohttp

+from edge_tts.exceptions import *
+
 from .constants import WSS_URL


-def get_headers_and_data(data):
+def get_headers_and_data(data: str | bytes) -> tuple[Dict[str, str], bytes]:
     """
     Returns the headers and data from the given data.
@@ -25,6 +29,8 @@ def get_headers_and_data(data):
""" """
if isinstance(data, str): if isinstance(data, str):
data = data.encode("utf-8") data = data.encode("utf-8")
if not isinstance(data, bytes):
raise TypeError("data must be str or bytes")
headers = {} headers = {}
for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"): for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"):
@@ -37,7 +43,7 @@ def get_headers_and_data(data):
     return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:])


-def remove_incompatible_characters(string):
+def remove_incompatible_characters(string: str | bytes) -> str:
     """
     The service does not support a couple character ranges.
     Most important being the vertical tab character which is
@@ -52,31 +58,30 @@ def remove_incompatible_characters(string):
""" """
if isinstance(string, bytes): if isinstance(string, bytes):
string = string.decode("utf-8") string = string.decode("utf-8")
if not isinstance(string, str):
raise TypeError("string must be str or bytes")
string = list(string) chars: List[str] = list(string)
for idx, char in enumerate(string): for idx, char in enumerate(chars):
code = ord(char) code: int = ord(char)
if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31): if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31):
string[idx] = " " chars[idx] = " "
return "".join(string) return "".join(chars)
def connect_id(): def connect_id() -> str:
""" """
Returns a UUID without dashes. Returns a UUID without dashes.
Args:
None
Returns: Returns:
str: A UUID without dashes. str: A UUID without dashes.
""" """
return str(uuid.uuid4()).replace("-", "") return str(uuid.uuid4()).replace("-", "")
def iter_bytes(my_bytes): def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
""" """
Iterates over bytes object Iterates over bytes object
@@ -90,20 +95,22 @@ def iter_bytes(my_bytes):
         yield my_bytes[i : i + 1]


-def split_text_by_byte_length(text, byte_length):
+def split_text_by_byte_length(text: bytes, byte_length: int) -> List[bytes]:
     """
     Splits a string into a list of strings of a given byte length
     while attempting to keep words together.

     Args:
-        text (byte): The string to be split.
-        byte_length (int): The byte length of each string in the list.
+        text (str or bytes): The string to be split.
+        byte_length (int): The maximum byte length of each string in the list.

     Returns:
-        list: A list of strings of the given byte length.
+        list: A list of bytes of the given byte length.
     """
     if isinstance(text, str):
         text = text.encode("utf-8")
+    if not isinstance(text, bytes):
+        raise TypeError("text must be str or bytes")

     words = []
     while len(text) > byte_length:
@@ -125,17 +132,10 @@ def split_text_by_byte_length(text, byte_length):
     return words


-def mkssml(text, voice, pitch, rate, volume):
+def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) -> str:
     """
     Creates a SSML string from the given parameters.

-    Args:
-        text (str): The text to be spoken.
-        voice (str): The voice to be used.
-        pitch (str): The pitch to be used.
-        rate (str): The rate to be used.
-        volume (str): The volume to be used.
-
     Returns:
         str: The SSML string.
     """
@@ -154,9 +154,6 @@ def date_to_string():
""" """
Return Javascript-style date string. Return Javascript-style date string.
Args:
None
Returns: Returns:
str: Javascript-style date string. str: Javascript-style date string.
""" """
@@ -171,15 +168,10 @@ def date_to_string():
     )


-def ssml_headers_plus_data(request_id, timestamp, ssml):
+def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str:
     """
     Returns the headers and data to be used in the request.

-    Args:
-        request_id (str): The request ID.
-        timestamp (str): The timestamp.
-        ssml (str): The SSML string.
-
     Returns:
         str: The headers and data to be used in the request.
     """
@@ -198,73 +190,86 @@ class Communicate:
     Class for communicating with the service.
     """

-    def __init__(self):
-        """
-        Initializes the Communicate class.
-        """
-        self.date = date_to_string()
-
-    async def run(
+    def __init__(
         self,
-        messages,
-        boundary_type=0,
-        codec="audio-24khz-48kbitrate-mono-mp3",
-        voice="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
-        pitch="+0Hz",
-        rate="+0%",
-        volume="+0%",
-        proxy=None,
+        text: str | List[str],
+        voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
+        *,
+        pitch: str = "+0Hz",
+        rate: str = "+0%",
+        volume: str = "+0%",
+        proxy: Optional[str] = None,
     ):
         """
-        Runs the Communicate class.
-
-        Args:
-            messages (str or list): A list of SSML strings or a single text.
-            boundery_type (int): The type of boundary to use. 0 for none, 1 for word_boundary, 2 for sentence_boundary.
-            codec (str): The codec to use.
-            voice (str): The voice to use.
-            pitch (str): The pitch to use.
-            rate (str): The rate to use.
-            volume (str): The volume to use.
-
-        Yields:
-            tuple: The subtitle offset, subtitle, and audio data.
+        Initializes the Communicate class.
+
+        Raises:
+            ValueError: If the voice is not valid.
         """
+        self.text = text
+        self.boundary_type = 1
+        self.codec = "audio-24khz-48kbitrate-mono-mp3"

-        word_boundary = False
-
-        if boundary_type > 0:
-            word_boundary = True
-        if boundary_type > 1:
-            raise ValueError(
-                "Invalid boundary type. SentenceBoundary is no longer supported."
-            )
-
-        word_boundary = str(word_boundary).lower()
-
-        websocket_max_size = 2 ** 16
+        self.voice = voice
+        # Possible values for voice are:
+        # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
+        # - cy-GB-NiaNeural
+        # Always send the first variant as that is what Microsoft Edge does.
+        match = re.match(r"^([a-z]{2})-([A-Z]{2})-(.+Neural)$", voice)
+        if match is not None:
+            self.voice = (
+                "Microsoft Server Speech Text to Speech Voice"
+                + f" ({match.group(1)}-{match.group(2)}, {match.group(3)})"
+            )
+        if (
+            re.match(
+                r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
+                self.voice,
+            )
+            is None
+        ):
+            raise ValueError(f"Invalid voice '{voice}'.")
+
+        if re.match(r"^[+-]\d+Hz$", pitch) is None:
+            raise ValueError(f"Invalid pitch '{pitch}'.")
+        self.pitch = pitch
+
+        if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None:
+            raise ValueError(f"Invalid rate '{rate}'.")
+        self.rate = rate
+
+        if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None:
+            raise ValueError(f"Invalid volume '{volume}'.")
+        self.volume = volume
+
+        self.proxy = proxy
+
+    async def stream(self):
+        """Streams audio and metadata from the service."""
+
+        websocket_max_size = 2**16
         overhead_per_message = (
             len(
                 ssml_headers_plus_data(
-                    connect_id(), self.date, mkssml("", voice, pitch, rate, volume)
+                    connect_id(),
+                    date_to_string(),
+                    mkssml("", self.voice, self.pitch, self.rate, self.volume),
                 )
             )
-            + 50
-        )  # margin of error
-
-        messages = split_text_by_byte_length(
-            escape(remove_incompatible_characters(messages)),
+            + 50  # margin of error
+        )
+        texts = split_text_by_byte_length(
+            escape(remove_incompatible_characters(self.text)),
             websocket_max_size - overhead_per_message,
         )

-        # Variables for the loop
-        download = False
-
         async with aiohttp.ClientSession(trust_env=True) as session:
             async with session.ws_connect(
                 f"{WSS_URL}&ConnectionId={connect_id()}",
                 compress=15,
                 autoclose=True,
                 autoping=True,
-                proxy=proxy,
+                proxy=self.proxy,
                 headers={
                     "Pragma": "no-cache",
                     "Cache-Control": "no-cache",
@@ -275,9 +280,19 @@ class Communicate:
" (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41", " (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41",
}, },
) as websocket: ) as websocket:
for message in messages: for text in texts:
# download indicates whether we should be expecting audio data,
# this is so what we avoid getting binary data from the websocket
# and falsely thinking it's audio data.
download = False
# audio_was_received indicates whether we have received audio data
# from the websocket. This is so we can raise an exception if we
# don't receive any audio data.
audio_was_received = False
# Each message needs to have the proper date # Each message needs to have the proper date
self.date = date_to_string() date = date_to_string()
# Prepare the request to be sent to the service. # Prepare the request to be sent to the service.
# #
@@ -290,26 +305,26 @@ class Communicate:
                     #
                     # Also pay close attention to double { } in request (escape for f-string).
                     request = (
-                        f"X-Timestamp:{self.date}\r\n"
+                        f"X-Timestamp:{date}\r\n"
                         "Content-Type:application/json; charset=utf-8\r\n"
                         "Path:speech.config\r\n\r\n"
                         '{"context":{"synthesis":{"audio":{"metadataoptions":{'
-                        f'"sentenceBoundaryEnabled":false,'
-                        f'"wordBoundaryEnabled":{word_boundary}}},"outputFormat":"{codec}"'
+                        '"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},'
+                        f'"outputFormat":"{self.codec}"'
                         "}}}}\r\n"
                     )
-                    # Send the request to the service.
                     await websocket.send_str(request)

-                    # Send the message itself.
                     await websocket.send_str(
                         ssml_headers_plus_data(
                             connect_id(),
-                            self.date,
-                            mkssml(message, voice, pitch, rate, volume),
+                            date,
+                            mkssml(
+                                text, self.voice, self.pitch, self.rate, self.volume
+                            ),
                         )
                     )

-                    # Begin listening for the response.
                     async for received in websocket:
                         if received.type == aiohttp.WSMsgType.TEXT:
                             parameters, data = get_headers_and_data(received.data)
@@ -329,35 +344,34 @@ class Communicate:
and parameters["Path"] == "audio.metadata" and parameters["Path"] == "audio.metadata"
): ):
metadata = json.loads(data) metadata = json.loads(data)
metadata_type = metadata["Metadata"][0]["Type"] for i in range(len(metadata["Metadata"])):
metadata_offset = metadata["Metadata"][0]["Data"][ metadata_type = metadata["Metadata"][i]["Type"]
"Offset" metadata_offset = metadata["Metadata"][i]["Data"][
] "Offset"
if metadata_type == "WordBoundary":
metadata_duration = metadata["Metadata"][0]["Data"][
"Duration"
] ]
metadata_text = metadata["Metadata"][0]["Data"][ if metadata_type == "WordBoundary":
"text" metadata_duration = metadata["Metadata"][i][
]["Text"] "Data"
yield ( ]["Duration"]
[ metadata_text = metadata["Metadata"][i]["Data"][
metadata_offset, "text"
metadata_duration, ]["Text"]
], yield {
metadata_text, "type": metadata_type,
None, "offset": metadata_offset,
) "duration": metadata_duration,
elif metadata_type == "SentenceBoundary": "text": metadata_text,
raise NotImplementedError( }
"SentenceBoundary is not supported due to being broken." elif metadata_type == "SentenceBoundary":
) raise UnknownResponse(
elif metadata_type == "SessionEnd": "SentenceBoundary is not supported due to being broken."
continue )
else: elif metadata_type == "SessionEnd":
raise NotImplementedError( continue
f"Unknown metadata type: {metadata_type}" else:
) raise UnknownResponse(
f"Unknown metadata type: {metadata_type}"
)
elif ( elif (
"Path" in parameters "Path" in parameters
and parameters["Path"] == "response" and parameters["Path"] == "response"
@@ -368,25 +382,60 @@ class Communicate:
                                 Content-Type:application/json; charset=utf-8
                                 Path:response

-                                {"context":{"serviceTag":"yyyyyyyyyyyyyyyyyyy"},"audio":{"type":"inline","streamId":"zzzzzzzzzzzzzzzzz"}}
+                                {"context":{"serviceTag":"yyyyyyyyyyyyyyyyyyy"},"audio":
+                                {"type":"inline","streamId":"zzzzzzzzzzzzzzzzz"}}
                                 """
                                 pass
                             else:
-                                raise ValueError(
+                                raise UnknownResponse(
                                     "The response from the service is not recognized.\n"
                                     + received.data
                                 )
                         elif received.type == aiohttp.WSMsgType.BINARY:
                             if download:
-                                yield (
-                                    None,
-                                    None,
-                                    b"Path:audio\r\n".join(
+                                yield {
+                                    "type": "audio",
+                                    "data": b"Path:audio\r\n".join(
                                         received.data.split(b"Path:audio\r\n")[1:]
                                     ),
-                                )
+                                }
+                                audio_was_received = True
                             else:
-                                raise ValueError(
+                                raise UnexpectedResponse(
                                     "The service sent a binary message, but we are not expecting one."
                                 )
+
+                    await websocket.close()
+
+                    if not audio_was_received:
+                        raise NoAudioReceived(
+                            "No audio was received from the service. Please verify that your parameters are correct."
+                        )
+
+    async def save(
+        self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None
+    ):
+        """
+        Save the audio and metadata to the specified files.
+        """
+        written_audio = False
+        try:
+            audio = open(audio_fname, "wb")
+            metadata = None
+            if metadata_fname is not None:
+                metadata = open(metadata_fname, "w")
+
+            async for message in self.stream():
+                if message["type"] == "audio":
+                    audio.write(message["data"])
+                    written_audio = True
+                elif metadata is not None and message["type"] == "WordBoundary":
+                    json.dump(message, metadata)
+                    metadata.write("\n")
+        finally:
+            audio.close()
+            if metadata is not None:
+                metadata.close()
+        if not written_audio:
+            raise NoAudioReceived(
+                "No audio was received from the service, so the file is empty."
+            )
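
Taken together, Communicate now exposes a plain constructor plus an async stream() generator that yields dict messages ("audio" chunks and "WordBoundary" events) in place of the old run() tuples. A minimal usage sketch based on the code above (the sample text, voice, and event-loop setup are illustrative assumptions, not part of this commit):

import asyncio

from edge_tts import Communicate


async def demo() -> None:
    # Short names like "en-US-AriaNeural" are expanded by __init__ into the
    # full "Microsoft Server Speech Text to Speech Voice (...)" form.
    communicate = Communicate("Hello, world!", "en-US-AriaNeural")
    async for message in communicate.stream():
        if message["type"] == "audio":
            # message["data"] holds a chunk of encoded audio (mp3 by default).
            pass
        elif message["type"] == "WordBoundary":
            print(message["offset"], message["duration"], message["text"])


asyncio.run(demo())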

View File

@@ -0,0 +1,13 @@
+class UnknownResponse(Exception):
+    """Raised when an unknown response is received from the server."""
+
+
+class UnexpectedResponse(Exception):
+    """Raised when an unexpected response is received from the server.
+
+    This hasn't happened yet, but it's possible that the server will
+    change its response format in the future."""
+
+
+class NoAudioReceived(Exception):
+    """Raised when no audio is received from the server."""

View File

@@ -28,9 +28,9 @@ def mktimestamp(time_unit):
     Returns:
         str: The timecode of the subtitle.
     """
-    hour = math.floor(time_unit / 10 ** 7 / 3600)
-    minute = math.floor((time_unit / 10 ** 7 / 60) % 60)
-    seconds = (time_unit / 10 ** 7) % 60
+    hour = math.floor(time_unit / 10**7 / 3600)
+    minute = math.floor((time_unit / 10**7 / 60) % 60)
+    seconds = (time_unit / 10**7) % 60
     return f"{hour:02d}:{minute:02d}:{seconds:06.3f}"
@@ -48,7 +48,7 @@ class SubMaker:
         subtitles should overlap.
         """
         self.subs_and_offset = []
-        self.overlapping = overlapping * (10 ** 7)
+        self.overlapping = overlapping * (10**7)

     def create_sub(self, timestamp, text):
         """

View File

@@ -11,9 +11,6 @@ from edge_tts import Communicate, SubMaker, list_voices
 async def _list_voices(proxy):
-    """
-    List available voices.
-    """
     for idx, voice in enumerate(await list_voices(proxy=proxy)):
         if idx != 0:
             print()
@@ -26,34 +23,36 @@ async def _list_voices(proxy):
 async def _tts(args):
-    tts = Communicate()
-    subs = SubMaker(args.overlapping)
-    if args.write_media:
-        media_file = open(args.write_media, "wb")  # pylint: disable=consider-using-with
-    async for i in tts.run(
+    tts = await Communicate(
         args.text,
-        args.boundary_type,
-        args.codec,
         args.voice,
-        args.pitch,
-        args.rate,
-        args.volume,
         proxy=args.proxy,
-    ):
-        if i[2] is not None:
-            if not args.write_media:
-                sys.stdout.buffer.write(i[2])
-            else:
-                media_file.write(i[2])
-        elif i[0] is not None and i[1] is not None:
-            subs.create_sub(i[0], i[1])
-
-    if args.write_media:
-        media_file.close()
-
-    if not args.write_subtitles:
-        sys.stderr.write(subs.generate_subs())
-    else:
-        with open(args.write_subtitles, "w", encoding="utf-8") as file:
-            file.write(subs.generate_subs())
+        rate=args.rate,
+        volume=args.volume,
+    )
+    try:
+        media_file = None
+        if args.write_media:
+            media_file = open(args.write_media, "wb")
+
+        subs = SubMaker(args.overlapping)
+        async for data in tts.stream():
+            if data["type"] == "audio":
+                if not args.write_media:
+                    sys.stdout.buffer.write(data["data"])
+                else:
+                    media_file.write(data["data"])
+            elif data["type"] == "WordBoundary":
+                subs.create_sub([data["offset"], data["duration"]], data["text"])
+
+        if not args.write_subtitles:
+            sys.stderr.write(subs.generate_subs())
+        else:
+            with open(args.write_subtitles, "w", encoding="utf-8") as file:
+                file.write(subs.generate_subs())
+    finally:
+        if media_file is not None:
+            media_file.close()
async def _main(): async def _main():
@@ -64,23 +63,13 @@ async def _main():
     parser.add_argument(
         "-v",
         "--voice",
-        help="voice for TTS. "
-        "Default: Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
-        default="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
-    )
-    parser.add_argument(
-        "-c",
-        "--codec",
-        help="codec format. Default: audio-24khz-48kbitrate-mono-mp3. "
-        "Another choice is webm-24khz-16bit-mono-opus. "
-        "For more info check https://bit.ly/2T33h6S",
-        default="audio-24khz-48kbitrate-mono-mp3",
+        help="voice for TTS. " "Default: en-US-AriaNeural",
+        default="en-US-AriaNeural",
     )
     group.add_argument(
         "-l",
         "--list-voices",
-        help="lists available voices. "
-        "Edge's list is incomplete so check https://bit.ly/2SFq1d3",
+        help="lists available voices",
         action="store_true",
     )
     parser.add_argument(
@@ -109,32 +98,19 @@ async def _main():
         type=float,
     )
     parser.add_argument(
-        "-b",
-        "--boundary-type",
-        help="set boundary type for subtitles. Default 0 for none. Set 1 for word_boundary.",
-        default=0,
-        type=int,
-    )
-    parser.add_argument(
-        "--write-media", help="instead of stdout, send media output to provided file"
+        "--write-media", help="send media output to file instead of stdout"
     )
     parser.add_argument(
         "--write-subtitles",
-        help="instead of stderr, send subtitle output to provided file (implies boundary-type is 1)",
-    )
-    parser.add_argument(
-        "--proxy",
-        help="proxy",
+        help="send subtitle output to provided file instead of stderr",
     )
+    parser.add_argument("--proxy", help="use a proxy for TTS and voice list.")
     args = parser.parse_args()
     if args.list_voices:
         await _list_voices(args.proxy)
         sys.exit(0)
-    if args.write_subtitles and args.boundary_type == 0:
-        args.boundary_type = 1
     if args.text is not None or args.file is not None:
         if args.file is not None:
             # we need to use sys.stdin.read() because some devices
@@ -151,9 +127,6 @@ async def _main():
 def main():
-    """
-    Main function.
-    """
     asyncio.get_event_loop().run_until_complete(_main())
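
With the --boundary-type and --codec options removed, a typical invocation of the simplified CLI reduces to something like the following (the text and output paths are illustrative, and --text comes from the pre-existing input options not shown in this hunk):

edge-tts --voice en-US-AriaNeural --text "Hello, world!" --write-media hello.mp3 --write-subtitles hello.vtt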