add more typing

This commit is contained in:
rany2
2023-01-05 00:56:01 +02:00
parent efe0cbedde
commit c4c3dc5a13
12 changed files with 129 additions and 117 deletions

View File

@@ -9,7 +9,7 @@ import asyncio
import edge_tts import edge_tts
async def main(): async def main() -> None:
TEXT = "Hello World!" TEXT = "Hello World!"
VOICE = "en-GB-SoniaNeural" VOICE = "en-GB-SoniaNeural"
OUTPUT_FILE = "test.mp3" OUTPUT_FILE = "test.mp3"

View File

@@ -11,7 +11,7 @@ import edge_tts
from edge_tts import VoicesManager from edge_tts import VoicesManager
async def main(): async def main() -> None:
voices = await VoicesManager.create() voices = await VoicesManager.create()
voice = voices.find(Gender="Male", Language="es") voice = voices.find(Gender="Male", Language="es")
# Also supports Locales # Also supports Locales

View File

@@ -1,3 +1,4 @@
find src examples -name '*.py' | xargs black find src examples -name '*.py' | xargs black
find src examples -name '*.py' | xargs isort find src examples -name '*.py' | xargs isort
find src examples -name '*.py' | xargs pylint find src examples -name '*.py' | xargs pylint
find src examples -name '*.py' | xargs mypy

13
mypy.ini Normal file
View File

@@ -0,0 +1,13 @@
[mypy]
warn_return_any = True
warn_unused_configs = True
#disallow_any_unimported = True
#disallow_any_expr = True
#disallow_any_decorated = True
#disallow_any_explicit = True
#disallow_any_generics = True
#disallow_subclassing_any = True
#disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True

View File

@@ -27,4 +27,11 @@ where=src
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
edge-tts = edge_tts.__main__:main edge-tts = edge_tts.__main__:main
edge-playback = edge_playback.__init__:main edge-playback = edge_playback.__main__:main
[options.extras_require]
dev =
black
isort
mypy
pylint

View File

@@ -1,63 +0,0 @@
#!/usr/bin/env python3
"""
Playback TTS with subtitles using edge-tts and mpv.
"""
import os
import subprocess
import sys
import tempfile
from shutil import which
def main():
depcheck_failed = False
if not which("mpv"):
print("mpv is not installed.", file=sys.stderr)
depcheck_failed = True
if not which("edge-tts"):
print("edge-tts is not installed.", file=sys.stderr)
depcheck_failed = True
if depcheck_failed:
print("Please install the missing dependencies.", file=sys.stderr)
sys.exit(1)
media = None
subtitle = None
try:
media = tempfile.NamedTemporaryFile(delete=False)
media.close()
subtitle = tempfile.NamedTemporaryFile(delete=False)
subtitle.close()
print(f"Media file: {media.name}")
print(f"Subtitle file: {subtitle.name}\n")
with subprocess.Popen(
[
"edge-tts",
f"--write-media={media.name}",
f"--write-subtitles={subtitle.name}",
]
+ sys.argv[1:]
) as process:
process.communicate()
with subprocess.Popen(
[
"mpv",
f"--sub-file={subtitle.name}",
media.name,
]
) as process:
process.communicate()
finally:
if media is not None:
os.unlink(media.name)
if subtitle is not None:
os.unlink(subtitle.name)
if __name__ == "__main__":
main()

View File

@@ -1,10 +1,63 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
This is the main file for the edge_playback package. Playback TTS with subtitles using edge-tts and mpv.
""" """
from edge_playback.__init__ import main import os
import subprocess
import sys
import tempfile
from shutil import which
def main() -> None:
depcheck_failed = False
if not which("mpv"):
print("mpv is not installed.", file=sys.stderr)
depcheck_failed = True
if not which("edge-tts"):
print("edge-tts is not installed.", file=sys.stderr)
depcheck_failed = True
if depcheck_failed:
print("Please install the missing dependencies.", file=sys.stderr)
sys.exit(1)
media = None
subtitle = None
try:
media = tempfile.NamedTemporaryFile(delete=False)
media.close()
subtitle = tempfile.NamedTemporaryFile(delete=False)
subtitle.close()
print(f"Media file: {media.name}")
print(f"Subtitle file: {subtitle.name}\n")
with subprocess.Popen(
[
"edge-tts",
f"--write-media={media.name}",
f"--write-subtitles={subtitle.name}",
]
+ sys.argv[1:]
) as process:
process.communicate()
with subprocess.Popen(
[
"mpv",
f"--sub-file={subtitle.name}",
media.name,
]
) as process:
process.communicate()
finally:
if media is not None:
os.unlink(media.name)
if subtitle is not None:
os.unlink(subtitle.name)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

View File

@@ -7,7 +7,7 @@ import json
import re import re
import time import time
import uuid import uuid
from typing import Dict, Generator, List, Optional from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from xml.sax.saxutils import escape from xml.sax.saxutils import escape
import aiohttp import aiohttp
@@ -96,7 +96,7 @@ def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
yield my_bytes[i : i + 1] yield my_bytes[i : i + 1]
def split_text_by_byte_length(text: bytes, byte_length: int) -> List[bytes]: def split_text_by_byte_length(text: str | bytes, byte_length: int) -> List[bytes]:
""" """
Splits a string into a list of strings of a given byte length Splits a string into a list of strings of a given byte length
while attempting to keep words together. while attempting to keep words together.
@@ -151,7 +151,7 @@ def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) ->
return ssml return ssml
def date_to_string(): def date_to_string() -> str:
""" """
Return Javascript-style date string. Return Javascript-style date string.
@@ -193,7 +193,7 @@ class Communicate:
def __init__( def __init__(
self, self,
text: str | List[str], text: str,
voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
*, *,
pitch: str = "+0Hz", pitch: str = "+0Hz",
@@ -207,9 +207,9 @@ class Communicate:
Raises: Raises:
ValueError: If the voice is not valid. ValueError: If the voice is not valid.
""" """
self.text = text self.text: str = text
self.codec = "audio-24khz-48kbitrate-mono-mp3" self.codec: str = "audio-24khz-48kbitrate-mono-mp3"
self.voice = voice self.voice: str = voice
# Possible values for voice are: # Possible values for voice are:
# - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural) # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
# - cy-GB-NiaNeural # - cy-GB-NiaNeural
@@ -232,19 +232,19 @@ class Communicate:
if re.match(r"^[+-]\d+Hz$", pitch) is None: if re.match(r"^[+-]\d+Hz$", pitch) is None:
raise ValueError(f"Invalid pitch '{pitch}'.") raise ValueError(f"Invalid pitch '{pitch}'.")
self.pitch = pitch self.pitch: str = pitch
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None: if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None:
raise ValueError(f"Invalid rate '{rate}'.") raise ValueError(f"Invalid rate '{rate}'.")
self.rate = rate self.rate: str = rate
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None: if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None:
raise ValueError(f"Invalid volume '{volume}'.") raise ValueError(f"Invalid volume '{volume}'.")
self.volume = volume self.volume: str = volume
self.proxy = proxy self.proxy: Optional[str] = proxy
async def stream(self): async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
"""Streams audio and metadata from the service.""" """Streams audio and metadata from the service."""
websocket_max_size = 2**16 websocket_max_size = 2**16
@@ -403,7 +403,7 @@ class Communicate:
async def save( async def save(
self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None
): ) -> None:
""" """
Save the audio and metadata to the specified files. Save the audio and metadata to the specified files.
""" """

View File

@@ -3,13 +3,14 @@ list_voices package for edge_tts.
""" """
import json import json
from typing import Any, Optional
import aiohttp import aiohttp
from .constants import VOICE_LIST from .constants import VOICE_LIST
async def list_voices(*, proxy=None): async def list_voices(*, proxy: Optional[str] = None) -> Any:
""" """
List all available voices and their attributes. List all available voices and their attributes.
@@ -47,7 +48,7 @@ class VoicesManager:
""" """
@classmethod @classmethod
async def create(cls): async def create(cls): # type: ignore
""" """
Creates a VoicesManager object and populates it with all available voices. Creates a VoicesManager object and populates it with all available voices.
""" """
@@ -59,12 +60,12 @@ class VoicesManager:
] ]
return self return self
def find(self, **kwargs): def find(self, **kwargs: Any) -> list[dict[str, Any]]:
""" """
Finds all matching voices based on the provided attributes. Finds all matching voices based on the provided attributes.
""" """
matching_voices = [ matching_voices = [
voice for voice in self.voices if kwargs.items() <= voice.items() voice for voice in self.voices if kwargs.items() <= voice.items() # type: ignore
] ]
return matching_voices return matching_voices

View File

@@ -6,10 +6,11 @@ information provided by the service easier.
""" """
import math import math
from typing import List, Tuple
from xml.sax.saxutils import escape, unescape from xml.sax.saxutils import escape, unescape
def formatter(offset1, offset2, subdata): def formatter(offset1: float, offset2: float, subdata: str) -> str:
""" """
formatter returns the timecode and the text of the subtitle. formatter returns the timecode and the text of the subtitle.
""" """
@@ -19,7 +20,7 @@ def formatter(offset1, offset2, subdata):
) )
def mktimestamp(time_unit): def mktimestamp(time_unit: float) -> str:
""" """
mktimestamp returns the timecode of the subtitle. mktimestamp returns the timecode of the subtitle.
@@ -39,7 +40,7 @@ class SubMaker:
SubMaker class SubMaker class
""" """
def __init__(self, overlapping=1): def __init__(self, overlapping: int = 1) -> None:
""" """
SubMaker constructor. SubMaker constructor.
@@ -47,10 +48,11 @@ class SubMaker:
overlapping (int): The amount of time in seconds that the overlapping (int): The amount of time in seconds that the
subtitles should overlap. subtitles should overlap.
""" """
self.subs_and_offset = [] self.offset: List[Tuple[float, float]] = []
self.overlapping = overlapping * (10**7) self.subs: List[str] = []
self.overlapping: int = overlapping * (10**7)
def create_sub(self, timestamp, text): def create_sub(self, timestamp: Tuple[float, float], text: str) -> None:
""" """
create_sub creates a subtitle with the given timestamp and text create_sub creates a subtitle with the given timestamp and text
and adds it to the list of subtitles and adds it to the list of subtitles
@@ -62,40 +64,37 @@ class SubMaker:
Returns: Returns:
None None
""" """
timestamp[1] += timestamp[0] self.offset.append((timestamp[0], timestamp[0] + timestamp[1]))
self.subs_and_offset.append(timestamp) self.subs.append(text)
self.subs_and_offset.append(text)
def generate_subs(self): def generate_subs(self) -> str:
""" """
generate_subs generates the complete subtitle file. generate_subs generates the complete subtitle file.
Returns: Returns:
str: The complete subtitle file. str: The complete subtitle file.
""" """
if len(self.subs_and_offset) >= 2: if len(self.subs) == len(self.offset):
data = "WEBVTT\r\n\r\n" data = "WEBVTT\r\n\r\n"
for offset, subs in zip( for offset, subs in zip(self.offset, self.subs):
self.subs_and_offset[::2], self.subs_and_offset[1::2]
):
subs = unescape(subs) subs = unescape(subs)
subs = [subs[i : i + 79] for i in range(0, len(subs), 79)] split_subs: List[str] = [subs[i : i + 79] for i in range(0, len(subs), 79)]
for i in range(len(subs) - 1): for i in range(len(split_subs) - 1):
sub = subs[i] sub = split_subs[i]
split_at_word = True split_at_word = True
if sub[-1] == " ": if sub[-1] == " ":
subs[i] = sub[:-1] split_subs[i] = sub[:-1]
split_at_word = False split_at_word = False
if sub[0] == " ": if sub[0] == " ":
subs[i] = sub[1:] split_subs[i] = sub[1:]
split_at_word = False split_at_word = False
if split_at_word: if split_at_word:
subs[i] += "-" split_subs[i] += "-"
subs = "\r\n".join(subs) subs = "\r\n".join(split_subs)
data += formatter(offset[0], offset[1] + self.overlapping, subs) data += formatter(offset[0], offset[1] + self.overlapping, subs)
return data return data

View File

@@ -5,12 +5,14 @@ Main package.
import argparse import argparse
import asyncio import asyncio
from io import BufferedWriter
import sys import sys
from typing import Any
from edge_tts import Communicate, SubMaker, list_voices from edge_tts import Communicate, SubMaker, list_voices
async def _print_voices(proxy): async def _print_voices(*, proxy: str) -> None:
"""Print all available voices.""" """Print all available voices."""
for idx, voice in enumerate(await list_voices(proxy=proxy)): for idx, voice in enumerate(await list_voices(proxy=proxy)):
if idx != 0: if idx != 0:
@@ -23,9 +25,9 @@ async def _print_voices(proxy):
print(f"{key}: {voice[key]}") print(f"{key}: {voice[key]}")
async def _run_tts(args): async def _run_tts(args: Any) -> None:
"""Run TTS after parsing arguments from command line.""" """Run TTS after parsing arguments from command line."""
tts = await Communicate( tts = Communicate(
args.text, args.text,
args.voice, args.voice,
proxy=args.proxy, proxy=args.proxy,
@@ -35,18 +37,17 @@ async def _run_tts(args):
try: try:
media_file = None media_file = None
if args.write_media: if args.write_media:
# pylint: disable=consider-using-with
media_file = open(args.write_media, "wb") media_file = open(args.write_media, "wb")
subs = SubMaker(args.overlapping) subs = SubMaker(args.overlapping)
async for data in tts.stream(): async for data in tts.stream():
if data["type"] == "audio": if data["type"] == "audio":
if not args.write_media: if isinstance(media_file, BufferedWriter):
sys.stdout.buffer.write(data["data"])
else:
media_file.write(data["data"]) media_file.write(data["data"])
else:
sys.stdout.buffer.write(data["data"])
elif data["type"] == "WordBoundary": elif data["type"] == "WordBoundary":
subs.create_sub([data["offset"], data["duration"]], data["text"]) subs.create_sub((data["offset"], data["duration"]), data["text"])
if not args.write_subtitles: if not args.write_subtitles:
sys.stderr.write(subs.generate_subs()) sys.stderr.write(subs.generate_subs())
@@ -58,7 +59,7 @@ async def _run_tts(args):
media_file.close() media_file.close()
async def _async_main(): async def _async_main() -> None:
parser = argparse.ArgumentParser(description="Microsoft Edge TTS") parser = argparse.ArgumentParser(description="Microsoft Edge TTS")
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-t", "--text", help="what TTS will say") group.add_argument("-t", "--text", help="what TTS will say")
@@ -111,7 +112,7 @@ async def _async_main():
args = parser.parse_args() args = parser.parse_args()
if args.list_voices: if args.list_voices:
await _print_voices(args.proxy) await _print_voices(proxy=args.proxy)
sys.exit(0) sys.exit(0)
if args.text is not None or args.file is not None: if args.text is not None or args.file is not None:
@@ -129,7 +130,7 @@ async def _async_main():
await _run_tts(args) await _run_tts(args)
def main(): def main() -> None:
"""Run the main function using asyncio.""" """Run the main function using asyncio."""
asyncio.get_event_loop().run_until_complete(_async_main()) asyncio.get_event_loop().run_until_complete(_async_main())