diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..ea05295 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: + push: + paths: + - '*.py' + +jobs: + mypy: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v1 + with: + python-version: 3.7.4 + architecture: x64 + - name: Checkout + uses: actions/checkout@v1 + - name: Install mypy + run: pip install mypy + - name: Run mypy + uses: sasanquaneuf/mypy-github-action@releases/v1 + with: + checkName: 'mypy' # NOTE: this needs to be the same as the job name + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/examples/basic_generation.py b/examples/basic_generation.py new file mode 100644 index 0000000..6c973b9 --- /dev/null +++ b/examples/basic_generation.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +""" +Basic example of edge_tts usage. +""" + +import asyncio + +import edge_tts + + +async def main() -> None: + TEXT = "Hello World!" + VOICE = "en-GB-SoniaNeural" + OUTPUT_FILE = "test.mp3" + + communicate = edge_tts.Communicate(TEXT, VOICE) + await communicate.save(OUTPUT_FILE) + + +if __name__ == "__main__": + asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/dynamic_voice_selection.py b/examples/dynamic_voice_selection.py index fb85a1c..e7e67fb 100644 --- a/examples/dynamic_voice_selection.py +++ b/examples/dynamic_voice_selection.py @@ -1,26 +1,28 @@ +#!/usr/bin/env python3 + +""" +Example of dynamic voice selection using VoicesManager. +""" + import asyncio -import edge_tts -from edge_tts import VoicesManager import random -async def main(): - """ - Main function - """ +import edge_tts +from edge_tts import VoicesManager + + +async def main() -> None: voices = await VoicesManager.create() - voice = voices.find(Gender="Male", Language="es") + voice = voices.find(Gender="Male", Language="es") # Also supports Locales # voice = voices.find(Gender="Female", Locale="es-AR") VOICE = random.choice(voice)["ShortName"] TEXT = "Hoy es un buen día." OUTPUT_FILE = "spanish.mp3" - communicate = edge_tts.Communicate() + communicate = edge_tts.Communicate(TEXT, VOICE) + await communicate.save(OUTPUT_FILE) - with open(OUTPUT_FILE, "wb") as f: - async for i in communicate.run(TEXT, voice=VOICE): - if i[2] is not None: - f.write(i[2]) if __name__ == "__main__": asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/example.py b/examples/example.py deleted file mode 100644 index 14ce848..0000000 --- a/examples/example.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python3 -""" -Example Python script that shows how to use edge-tts as a module -""" - -import asyncio -import edge_tts - -async def main(): - """ - Main function - """ - TEXT = "Hello World!" - VOICE = "en-GB-SoniaNeural" - OUTPUT_FILE = "test.mp3" - - communicate = edge_tts.Communicate() - with open(OUTPUT_FILE, "wb") as f: - async for i in communicate.run(TEXT, voice=VOICE): - if i[2] is not None: - f.write(i[2]) - -if __name__ == "__main__": - asyncio.get_event_loop().run_until_complete(main()) \ No newline at end of file diff --git a/lint.sh b/lint.sh index b80309d..6532c9b 100755 --- a/lint.sh +++ b/lint.sh @@ -1,3 +1,4 @@ find src examples -name '*.py' | xargs black find src examples -name '*.py' | xargs isort find src examples -name '*.py' | xargs pylint +find src examples -name '*.py' | xargs mypy diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..01e108d --- /dev/null +++ b/mypy.ini @@ -0,0 +1,29 @@ +[mypy] +disallow_any_unimported = True +disallow_any_expr = False +disallow_any_decorated = True +disallow_any_explicit = False +disallow_any_generics = True +disallow_subclassing_any = True + +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True + +implicit_optional = False +strict_optional = True + +warn_redundant_casts = True +warn_unused_ignores = True +warn_no_return = True +warn_return_any = True +warn_unreachable = True + +strict_concatenate = True +strict_equality = True +strict = True + +[mypy-edge_tts.list_voices] +disallow_any_decorated = False diff --git a/setup.cfg b/setup.cfg index 8e8cd4f..1bf386e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,4 +27,11 @@ where=src [options.entry_points] console_scripts = edge-tts = edge_tts.__main__:main - edge-playback = edge_playback.__init__:main + edge-playback = edge_playback.__main__:main + +[options.extras_require] +dev = + black + isort + mypy + pylint diff --git a/src/edge_playback/__init__.py b/src/edge_playback/__init__.py deleted file mode 100644 index 86df7c4..0000000 --- a/src/edge_playback/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 - -""" -Playback TTS with subtitles using edge-tts and mpv. -""" - -import os -import subprocess -import sys -import tempfile -from shutil import which - - -def main(): - """ - Main function. - """ - if which("mpv") and which("edge-tts"): - media = tempfile.NamedTemporaryFile(delete=False) - subtitle = tempfile.NamedTemporaryFile(delete=False) - try: - media.close() - subtitle.close() - - print() - print(f"Media file: {media.name}") - print(f"Subtitle file: {subtitle.name}\n") - with subprocess.Popen( - [ - "edge-tts", - "--boundary-type=1", - f"--write-media={media.name}", - f"--write-subtitles={subtitle.name}", - ] - + sys.argv[1:] - ) as process: - process.communicate() - - with subprocess.Popen( - [ - "mpv", - "--keep-open=yes", - f"--sub-file={subtitle.name}", - media.name, - ] - ) as process: - process.communicate() - finally: - os.unlink(media.name) - os.unlink(subtitle.name) - else: - print("This script requires mpv and edge-tts.") - - -if __name__ == "__main__": - main() diff --git a/src/edge_playback/__main__.py b/src/edge_playback/__main__.py index 2ac8c12..027e892 100644 --- a/src/edge_playback/__main__.py +++ b/src/edge_playback/__main__.py @@ -1,10 +1,63 @@ #!/usr/bin/env python3 """ -This is the main file for the edge_playback package. +Playback TTS with subtitles using edge-tts and mpv. """ -from edge_playback.__init__ import main +import os +import subprocess +import sys +import tempfile +from shutil import which + + +def main() -> None: + depcheck_failed = False + if not which("mpv"): + print("mpv is not installed.", file=sys.stderr) + depcheck_failed = True + if not which("edge-tts"): + print("edge-tts is not installed.", file=sys.stderr) + depcheck_failed = True + if depcheck_failed: + print("Please install the missing dependencies.", file=sys.stderr) + sys.exit(1) + + media = None + subtitle = None + try: + media = tempfile.NamedTemporaryFile(delete=False) + media.close() + + subtitle = tempfile.NamedTemporaryFile(delete=False) + subtitle.close() + + print(f"Media file: {media.name}") + print(f"Subtitle file: {subtitle.name}\n") + with subprocess.Popen( + [ + "edge-tts", + f"--write-media={media.name}", + f"--write-subtitles={subtitle.name}", + ] + + sys.argv[1:] + ) as process: + process.communicate() + + with subprocess.Popen( + [ + "mpv", + f"--sub-file={subtitle.name}", + media.name, + ] + ) as process: + process.communicate() + finally: + if media is not None: + os.unlink(media.name) + if subtitle is not None: + os.unlink(subtitle.name) + if __name__ == "__main__": main() diff --git a/src/edge_playback/py.typed b/src/edge_playback/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/edge_tts/__init__.py b/src/edge_tts/__init__.py index 24b4e77..8ea0ee7 100644 --- a/src/edge_tts/__init__.py +++ b/src/edge_tts/__init__.py @@ -3,5 +3,7 @@ __init__ for edge_tts """ from .communicate import Communicate -from .list_voices import list_voices, VoicesManager -from .submaker import SubMaker \ No newline at end of file +from .list_voices import VoicesManager, list_voices +from .submaker import SubMaker + +__all__ = ["Communicate", "VoicesManager", "list_voices", "SubMaker"] diff --git a/src/edge_tts/communicate.py b/src/edge_tts/communicate.py index f81d063..29bbb9a 100644 --- a/src/edge_tts/communicate.py +++ b/src/edge_tts/communicate.py @@ -4,16 +4,21 @@ Communicate package. import json +import re import time import uuid +from typing import Any, AsyncGenerator, Dict, Generator, List, Optional from xml.sax.saxutils import escape import aiohttp +from edge_tts.exceptions import (NoAudioReceived, UnexpectedResponse, + UnknownResponse) + from .constants import WSS_URL -def get_headers_and_data(data): +def get_headers_and_data(data: str | bytes) -> tuple[Dict[str, str], bytes]: """ Returns the headers and data from the given data. @@ -25,6 +30,8 @@ def get_headers_and_data(data): """ if isinstance(data, str): data = data.encode("utf-8") + if not isinstance(data, bytes): + raise TypeError("data must be str or bytes") headers = {} for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"): @@ -37,7 +44,7 @@ def get_headers_and_data(data): return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:]) -def remove_incompatible_characters(string): +def remove_incompatible_characters(string: str | bytes) -> str: """ The service does not support a couple character ranges. Most important being the vertical tab character which is @@ -52,31 +59,30 @@ def remove_incompatible_characters(string): """ if isinstance(string, bytes): string = string.decode("utf-8") + if not isinstance(string, str): + raise TypeError("string must be str or bytes") - string = list(string) + chars: List[str] = list(string) - for idx, char in enumerate(string): - code = ord(char) + for idx, char in enumerate(chars): + code: int = ord(char) if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31): - string[idx] = " " + chars[idx] = " " - return "".join(string) + return "".join(chars) -def connect_id(): +def connect_id() -> str: """ Returns a UUID without dashes. - Args: - None - Returns: str: A UUID without dashes. """ return str(uuid.uuid4()).replace("-", "") -def iter_bytes(my_bytes): +def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]: """ Iterates over bytes object @@ -90,20 +96,22 @@ def iter_bytes(my_bytes): yield my_bytes[i : i + 1] -def split_text_by_byte_length(text, byte_length): +def split_text_by_byte_length(text: str | bytes, byte_length: int) -> List[bytes]: """ Splits a string into a list of strings of a given byte length while attempting to keep words together. Args: - text (byte): The string to be split. - byte_length (int): The byte length of each string in the list. + text (str or bytes): The string to be split. + byte_length (int): The maximum byte length of each string in the list. Returns: - list: A list of strings of the given byte length. + list: A list of bytes of the given byte length. """ if isinstance(text, str): text = text.encode("utf-8") + if not isinstance(text, bytes): + raise TypeError("text must be str or bytes") words = [] while len(text) > byte_length: @@ -125,17 +133,10 @@ def split_text_by_byte_length(text, byte_length): return words -def mkssml(text, voice, pitch, rate, volume): +def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) -> str: """ Creates a SSML string from the given parameters. - Args: - text (str): The text to be spoken. - voice (str): The voice to be used. - pitch (str): The pitch to be used. - rate (str): The rate to be used. - volume (str): The volume to be used. - Returns: str: The SSML string. """ @@ -150,13 +151,10 @@ def mkssml(text, voice, pitch, rate, volume): return ssml -def date_to_string(): +def date_to_string() -> str: """ Return Javascript-style date string. - Args: - None - Returns: str: Javascript-style date string. """ @@ -171,15 +169,10 @@ def date_to_string(): ) -def ssml_headers_plus_data(request_id, timestamp, ssml): +def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str: """ Returns the headers and data to be used in the request. - Args: - request_id (str): The request ID. - timestamp (str): The timestamp. - ssml (str): The SSML string. - Returns: str: The headers and data to be used in the request. """ @@ -198,73 +191,85 @@ class Communicate: Class for communicating with the service. """ - def __init__(self): - """ - Initializes the Communicate class. - """ - self.date = date_to_string() - - async def run( + def __init__( self, - messages, - boundary_type=0, - codec="audio-24khz-48kbitrate-mono-mp3", - voice="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", - pitch="+0Hz", - rate="+0%", - volume="+0%", - proxy=None, + text: str, + voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", + *, + pitch: str = "+0Hz", + rate: str = "+0%", + volume: str = "+0%", + proxy: Optional[str] = None, ): """ - Runs the Communicate class. + Initializes the Communicate class. - Args: - messages (str or list): A list of SSML strings or a single text. - boundery_type (int): The type of boundary to use. 0 for none, 1 for word_boundary, 2 for sentence_boundary. - codec (str): The codec to use. - voice (str): The voice to use. - pitch (str): The pitch to use. - rate (str): The rate to use. - volume (str): The volume to use. - - Yields: - tuple: The subtitle offset, subtitle, and audio data. + Raises: + ValueError: If the voice is not valid. """ - - word_boundary = False - - if boundary_type > 0: - word_boundary = True - if boundary_type > 1: - raise ValueError( - "Invalid boundary type. SentenceBoundary is no longer supported." + self.text: str = text + self.codec: str = "audio-24khz-48kbitrate-mono-mp3" + self.voice: str = voice + # Possible values for voice are: + # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural) + # - cy-GB-NiaNeural + # Always send the first variant as that is what Microsoft Edge does. + match = re.match(r"^([a-z]{2})-([A-Z]{2})-(.+Neural)$", voice) + if match is not None: + self.voice = ( + "Microsoft Server Speech Text to Speech Voice" + + f" ({match.group(1)}-{match.group(2)}, {match.group(3)})" ) - word_boundary = str(word_boundary).lower() + if ( + re.match( + r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$", + self.voice, + ) + is None + ): + raise ValueError(f"Invalid voice '{voice}'.") - websocket_max_size = 2 ** 16 + if re.match(r"^[+-]\d+Hz$", pitch) is None: + raise ValueError(f"Invalid pitch '{pitch}'.") + self.pitch: str = pitch + + if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None: + raise ValueError(f"Invalid rate '{rate}'.") + self.rate: str = rate + + if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None: + raise ValueError(f"Invalid volume '{volume}'.") + self.volume: str = volume + + self.proxy: Optional[str] = proxy + + async def stream(self) -> AsyncGenerator[Dict[str, Any], None]: + """Streams audio and metadata from the service.""" + + websocket_max_size = 2**16 overhead_per_message = ( len( ssml_headers_plus_data( - connect_id(), self.date, mkssml("", voice, pitch, rate, volume) + connect_id(), + date_to_string(), + mkssml("", self.voice, self.pitch, self.rate, self.volume), ) ) - + 50 - ) # margin of error - messages = split_text_by_byte_length( - escape(remove_incompatible_characters(messages)), + + 50 # margin of error + ) + texts = split_text_by_byte_length( + escape(remove_incompatible_characters(self.text)), websocket_max_size - overhead_per_message, ) - # Variables for the loop - download = False async with aiohttp.ClientSession(trust_env=True) as session: async with session.ws_connect( f"{WSS_URL}&ConnectionId={connect_id()}", compress=15, autoclose=True, autoping=True, - proxy=proxy, + proxy=self.proxy, headers={ "Pragma": "no-cache", "Cache-Control": "no-cache", @@ -275,9 +280,19 @@ class Communicate: " (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41", }, ) as websocket: - for message in messages: + for text in texts: + # download indicates whether we should be expecting audio data, + # this is so what we avoid getting binary data from the websocket + # and falsely thinking it's audio data. + download_audio = False + + # audio_was_received indicates whether we have received audio data + # from the websocket. This is so we can raise an exception if we + # don't receive any audio data. + audio_was_received = False + # Each message needs to have the proper date - self.date = date_to_string() + date = date_to_string() # Prepare the request to be sent to the service. # @@ -290,26 +305,26 @@ class Communicate: # # Also pay close attention to double { } in request (escape for f-string). request = ( - f"X-Timestamp:{self.date}\r\n" + f"X-Timestamp:{date}\r\n" "Content-Type:application/json; charset=utf-8\r\n" "Path:speech.config\r\n\r\n" '{"context":{"synthesis":{"audio":{"metadataoptions":{' - f'"sentenceBoundaryEnabled":false,' - f'"wordBoundaryEnabled":{word_boundary}}},"outputFormat":"{codec}"' + '"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},' + f'"outputFormat":"{self.codec}"' "}}}}\r\n" ) - # Send the request to the service. await websocket.send_str(request) - # Send the message itself. + await websocket.send_str( ssml_headers_plus_data( connect_id(), - self.date, - mkssml(message, voice, pitch, rate, volume), + date, + mkssml( + text, self.voice, self.pitch, self.rate, self.volume + ), ) ) - # Begin listening for the response. async for received in websocket: if received.type == aiohttp.WSMsgType.TEXT: parameters, data = get_headers_and_data(received.data) @@ -317,76 +332,101 @@ class Communicate: "Path" in parameters and parameters["Path"] == "turn.start" ): - download = True + download_audio = True elif ( "Path" in parameters and parameters["Path"] == "turn.end" ): - download = False + download_audio = False break elif ( "Path" in parameters and parameters["Path"] == "audio.metadata" ): metadata = json.loads(data) - metadata_type = metadata["Metadata"][0]["Type"] - metadata_offset = metadata["Metadata"][0]["Data"][ - "Offset" - ] - if metadata_type == "WordBoundary": - metadata_duration = metadata["Metadata"][0]["Data"][ - "Duration" + for i in range(len(metadata["Metadata"])): + metadata_type = metadata["Metadata"][i]["Type"] + metadata_offset = metadata["Metadata"][i]["Data"][ + "Offset" ] - metadata_text = metadata["Metadata"][0]["Data"][ - "text" - ]["Text"] - yield ( - [ - metadata_offset, - metadata_duration, - ], - metadata_text, - None, - ) - elif metadata_type == "SentenceBoundary": - raise NotImplementedError( - "SentenceBoundary is not supported due to being broken." - ) - elif metadata_type == "SessionEnd": - continue - else: - raise NotImplementedError( - f"Unknown metadata type: {metadata_type}" - ) + if metadata_type == "WordBoundary": + metadata_duration = metadata["Metadata"][i][ + "Data" + ]["Duration"] + metadata_text = metadata["Metadata"][i]["Data"][ + "text" + ]["Text"] + yield { + "type": metadata_type, + "offset": metadata_offset, + "duration": metadata_duration, + "text": metadata_text, + } + elif metadata_type == "SentenceBoundary": + raise UnknownResponse( + "SentenceBoundary is not supported due to being broken." + ) + elif metadata_type == "SessionEnd": + continue + else: + raise UnknownResponse( + f"Unknown metadata type: {metadata_type}" + ) elif ( "Path" in parameters and parameters["Path"] == "response" ): - # TODO: implement this: - """ - X-RequestId:xxxxxxxxxxxxxxxxxxxxxxxxx - Content-Type:application/json; charset=utf-8 - Path:response - - {"context":{"serviceTag":"yyyyyyyyyyyyyyyyyyy"},"audio":{"type":"inline","streamId":"zzzzzzzzzzzzzzzzz"}} - """ pass else: - raise ValueError( + raise UnknownResponse( "The response from the service is not recognized.\n" + received.data ) elif received.type == aiohttp.WSMsgType.BINARY: - if download: - yield ( - None, - None, - b"Path:audio\r\n".join( + if download_audio: + yield { + "type": "audio", + "data": b"Path:audio\r\n".join( received.data.split(b"Path:audio\r\n")[1:] ), - ) + } + audio_was_received = True else: - raise ValueError( - "The service sent a binary message, but we are not expecting one." + raise UnexpectedResponse( + "We received a binary message, but we are not expecting one." ) - await websocket.close() + + if not audio_was_received: + raise NoAudioReceived( + "No audio was received. Please verify that your parameters are correct." + ) + + async def save( + self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None + ) -> None: + """ + Save the audio and metadata to the specified files. + """ + written_audio = False + try: + audio = open(audio_fname, "wb") + metadata = None + if metadata_fname is not None: + metadata = open(metadata_fname, "w", encoding="utf-8") + + async for message in self.stream(): + if message["type"] == "audio": + audio.write(message["data"]) + written_audio = True + elif metadata is not None and message["type"] == "WordBoundary": + json.dump(message, metadata) + metadata.write("\n") + finally: + audio.close() + if metadata is not None: + metadata.close() + + if not written_audio: + raise NoAudioReceived( + "No audio was received from the service, so the file is empty." + ) diff --git a/src/edge_tts/exceptions.py b/src/edge_tts/exceptions.py new file mode 100644 index 0000000..16dcc57 --- /dev/null +++ b/src/edge_tts/exceptions.py @@ -0,0 +1,16 @@ +"""Exceptions for the Edge TTS project.""" + + +class UnknownResponse(Exception): + """Raised when an unknown response is received from the server.""" + + +class UnexpectedResponse(Exception): + """Raised when an unexpected response is received from the server. + + This hasn't happened yet, but it's possible that the server will + change its response format in the future.""" + + +class NoAudioReceived(Exception): + """Raised when no audio is received from the server.""" diff --git a/src/edge_tts/list_voices.py b/src/edge_tts/list_voices.py index f1d50a3..9793b5d 100644 --- a/src/edge_tts/list_voices.py +++ b/src/edge_tts/list_voices.py @@ -1,21 +1,21 @@ """ -list_voices package. +list_voices package for edge_tts. """ import json +from typing import Any, Dict, List, Optional import aiohttp from .constants import VOICE_LIST -async def list_voices(proxy=None): +async def list_voices(*, proxy: Optional[str] = None) -> Any: """ List all available voices and their attributes. This pulls data from the URL used by Microsoft Edge to return a list of - all available voices. However many more experimental voices are available - than are listed here. (See https://aka.ms/csspeech/voicenames) + all available voices. Returns: dict: A dictionary of voice attributes. @@ -47,20 +47,32 @@ class VoicesManager: A class to find the correct voice based on their attributes. """ + def __init__(self) -> None: + self.voices: List[Dict[str, Any]] = [] + self.called_create: bool = False + @classmethod - async def create(cls): + async def create(cls: Any) -> "VoicesManager": + """ + Creates a VoicesManager object and populates it with all available voices. + """ self = VoicesManager() self.voices = await list_voices() self.voices = [ {**voice, **{"Language": voice["Locale"].split("-")[0]}} for voice in self.voices ] + self.called_create = True return self - def find(self, **kwargs): + def find(self, **kwargs: Any) -> List[Dict[str, Any]]: """ Finds all matching voices based on the provided attributes. """ + if not self.called_create: + raise RuntimeError( + "VoicesManager.find() called before VoicesManager.create()" + ) matching_voices = [ voice for voice in self.voices if kwargs.items() <= voice.items() diff --git a/src/edge_tts/submaker.py b/src/edge_tts/submaker.py index 6988518..03a04db 100644 --- a/src/edge_tts/submaker.py +++ b/src/edge_tts/submaker.py @@ -6,10 +6,11 @@ information provided by the service easier. """ import math +from typing import List, Tuple from xml.sax.saxutils import escape, unescape -def formatter(offset1, offset2, subdata): +def formatter(offset1: float, offset2: float, subdata: str) -> str: """ formatter returns the timecode and the text of the subtitle. """ @@ -19,7 +20,7 @@ def formatter(offset1, offset2, subdata): ) -def mktimestamp(time_unit): +def mktimestamp(time_unit: float) -> str: """ mktimestamp returns the timecode of the subtitle. @@ -28,9 +29,9 @@ def mktimestamp(time_unit): Returns: str: The timecode of the subtitle. """ - hour = math.floor(time_unit / 10 ** 7 / 3600) - minute = math.floor((time_unit / 10 ** 7 / 60) % 60) - seconds = (time_unit / 10 ** 7) % 60 + hour = math.floor(time_unit / 10**7 / 3600) + minute = math.floor((time_unit / 10**7 / 60) % 60) + seconds = (time_unit / 10**7) % 60 return f"{hour:02d}:{minute:02d}:{seconds:06.3f}" @@ -39,7 +40,7 @@ class SubMaker: SubMaker class """ - def __init__(self, overlapping=1): + def __init__(self, overlapping: int = 1) -> None: """ SubMaker constructor. @@ -47,10 +48,11 @@ class SubMaker: overlapping (int): The amount of time in seconds that the subtitles should overlap. """ - self.subs_and_offset = [] - self.overlapping = overlapping * (10 ** 7) + self.offset: List[Tuple[float, float]] = [] + self.subs: List[str] = [] + self.overlapping: int = overlapping * (10**7) - def create_sub(self, timestamp, text): + def create_sub(self, timestamp: Tuple[float, float], text: str) -> None: """ create_sub creates a subtitle with the given timestamp and text and adds it to the list of subtitles @@ -62,40 +64,39 @@ class SubMaker: Returns: None """ - timestamp[1] += timestamp[0] - self.subs_and_offset.append(timestamp) - self.subs_and_offset.append(text) + self.offset.append((timestamp[0], timestamp[0] + timestamp[1])) + self.subs.append(text) - def generate_subs(self): + def generate_subs(self) -> str: """ generate_subs generates the complete subtitle file. Returns: str: The complete subtitle file. """ - if len(self.subs_and_offset) >= 2: + if len(self.subs) == len(self.offset): data = "WEBVTT\r\n\r\n" - for offset, subs in zip( - self.subs_and_offset[::2], self.subs_and_offset[1::2] - ): + for offset, subs in zip(self.offset, self.subs): subs = unescape(subs) - subs = [subs[i : i + 79] for i in range(0, len(subs), 79)] + split_subs: List[str] = [ + subs[i : i + 79] for i in range(0, len(subs), 79) + ] - for i in range(len(subs) - 1): - sub = subs[i] + for i in range(len(split_subs) - 1): + sub = split_subs[i] split_at_word = True if sub[-1] == " ": - subs[i] = sub[:-1] + split_subs[i] = sub[:-1] split_at_word = False if sub[0] == " ": - subs[i] = sub[1:] + split_subs[i] = sub[1:] split_at_word = False if split_at_word: - subs[i] += "-" + split_subs[i] += "-" - subs = "\r\n".join(subs) + subs = "\r\n".join(split_subs) data += formatter(offset[0], offset[1] + self.overlapping, subs) return data diff --git a/src/edge_tts/util.py b/src/edge_tts/util.py index 6a4a29f..638a5dc 100644 --- a/src/edge_tts/util.py +++ b/src/edge_tts/util.py @@ -6,14 +6,14 @@ Main package. import argparse import asyncio import sys +from io import BufferedWriter +from typing import Any from edge_tts import Communicate, SubMaker, list_voices -async def _list_voices(proxy): - """ - List available voices. - """ +async def _print_voices(*, proxy: str) -> None: + """Print all available voices.""" for idx, voice in enumerate(await list_voices(proxy=proxy)): if idx != 0: print() @@ -25,38 +25,41 @@ async def _list_voices(proxy): print(f"{key}: {voice[key]}") -async def _tts(args): - tts = Communicate() - subs = SubMaker(args.overlapping) - if args.write_media: - media_file = open(args.write_media, "wb") # pylint: disable=consider-using-with - async for i in tts.run( +async def _run_tts(args: Any) -> None: + """Run TTS after parsing arguments from command line.""" + tts = Communicate( args.text, - args.boundary_type, - args.codec, args.voice, - args.pitch, - args.rate, - args.volume, proxy=args.proxy, - ): - if i[2] is not None: - if not args.write_media: - sys.stdout.buffer.write(i[2]) - else: - media_file.write(i[2]) - elif i[0] is not None and i[1] is not None: - subs.create_sub(i[0], i[1]) - if args.write_media: - media_file.close() - if not args.write_subtitles: - sys.stderr.write(subs.generate_subs()) - else: - with open(args.write_subtitles, "w", encoding="utf-8") as file: - file.write(subs.generate_subs()) + rate=args.rate, + volume=args.volume, + ) + try: + media_file = None + if args.write_media: + media_file = open(args.write_media, "wb") + + subs = SubMaker(args.overlapping) + async for data in tts.stream(): + if data["type"] == "audio": + if isinstance(media_file, BufferedWriter): + media_file.write(data["data"]) + else: + sys.stdout.buffer.write(data["data"]) + elif data["type"] == "WordBoundary": + subs.create_sub((data["offset"], data["duration"]), data["text"]) + + if not args.write_subtitles: + sys.stderr.write(subs.generate_subs()) + else: + with open(args.write_subtitles, "w", encoding="utf-8") as file: + file.write(subs.generate_subs()) + finally: + if media_file is not None: + media_file.close() -async def _main(): +async def _async_main() -> None: parser = argparse.ArgumentParser(description="Microsoft Edge TTS") group = parser.add_mutually_exclusive_group(required=True) group.add_argument("-t", "--text", help="what TTS will say") @@ -64,23 +67,13 @@ async def _main(): parser.add_argument( "-v", "--voice", - help="voice for TTS. " - "Default: Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", - default="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", - ) - parser.add_argument( - "-c", - "--codec", - help="codec format. Default: audio-24khz-48kbitrate-mono-mp3. " - "Another choice is webm-24khz-16bit-mono-opus. " - "For more info check https://bit.ly/2T33h6S", - default="audio-24khz-48kbitrate-mono-mp3", + help="voice for TTS. " "Default: en-US-AriaNeural", + default="en-US-AriaNeural", ) group.add_argument( "-l", "--list-voices", - help="lists available voices. " - "Edge's list is incomplete so check https://bit.ly/2SFq1d3", + help="lists available voices", action="store_true", ) parser.add_argument( @@ -109,32 +102,19 @@ async def _main(): type=float, ) parser.add_argument( - "-b", - "--boundary-type", - help="set boundary type for subtitles. Default 0 for none. Set 1 for word_boundary.", - default=0, - type=int, - ) - parser.add_argument( - "--write-media", help="instead of stdout, send media output to provided file" + "--write-media", help="send media output to file instead of stdout" ) parser.add_argument( "--write-subtitles", - help="instead of stderr, send subtitle output to provided file (implies boundary-type is 1)", - ) - parser.add_argument( - "--proxy", - help="proxy", + help="send subtitle output to provided file instead of stderr", ) + parser.add_argument("--proxy", help="use a proxy for TTS and voice list.") args = parser.parse_args() if args.list_voices: - await _list_voices(args.proxy) + await _print_voices(proxy=args.proxy) sys.exit(0) - if args.write_subtitles and args.boundary_type == 0: - args.boundary_type = 1 - if args.text is not None or args.file is not None: if args.file is not None: # we need to use sys.stdin.read() because some devices @@ -147,14 +127,12 @@ async def _main(): with open(args.file, "r", encoding="utf-8") as file: args.text = file.read() - await _tts(args) + await _run_tts(args) -def main(): - """ - Main function. - """ - asyncio.get_event_loop().run_until_complete(_main()) +def main() -> None: + """Run the main function using asyncio.""" + asyncio.get_event_loop().run_until_complete(_async_main()) if __name__ == "__main__":