Merge pull request #31 from rany2/simplify

Simplify edge_tts library usage
This commit is contained in:
rany
2023-01-05 01:15:26 +02:00
committed by GitHub
16 changed files with 440 additions and 331 deletions

26
.github/workflows/lint.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
# Runs mypy on every push that touches a Python file.
name: Lint
on:
  push:
    paths:
      # '**.py' matches Python files in every directory. The previous
      # '*.py' pattern only matched files at the repository root, so
      # changes under src/ or examples/ never triggered this workflow.
      - '**.py'
jobs:
  mypy:
    runs-on: ubuntu-latest
    steps:
      - name: Setup Python
        uses: actions/setup-python@v1
        with:
          python-version: 3.7.4
          architecture: x64
      - name: Checkout
        uses: actions/checkout@v1
      - name: Install mypy
        run: pip install mypy
      - name: Run mypy
        uses: sasanquaneuf/mypy-github-action@releases/v1
        with:
          checkName: 'mypy' # NOTE: this needs to be the same as the job name
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
"""
Basic example of edge_tts usage.
"""
import asyncio
import edge_tts
async def main() -> None:
    """Synthesize a fixed greeting with a fixed voice and save it to test.mp3."""
    TEXT = "Hello World!"
    VOICE = "en-GB-SoniaNeural"
    OUTPUT_FILE = "test.mp3"

    communicate = edge_tts.Communicate(TEXT, VOICE)
    await communicate.save(OUTPUT_FILE)


if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop and closes it afterwards.
    # Calling asyncio.get_event_loop() when no loop is running is deprecated
    # and no longer creates a loop implicitly on current Python releases.
    asyncio.run(main())

View File

@@ -1,12 +1,17 @@
#!/usr/bin/env python3
"""
Example of dynamic voice selection using VoicesManager.
"""
import asyncio import asyncio
import edge_tts
from edge_tts import VoicesManager
import random import random
async def main(): import edge_tts
""" from edge_tts import VoicesManager
Main function
"""
async def main() -> None:
voices = await VoicesManager.create() voices = await VoicesManager.create()
voice = voices.find(Gender="Male", Language="es") voice = voices.find(Gender="Male", Language="es")
# Also supports Locales # Also supports Locales
@@ -15,12 +20,9 @@ async def main():
TEXT = "Hoy es un buen día." TEXT = "Hoy es un buen día."
OUTPUT_FILE = "spanish.mp3" OUTPUT_FILE = "spanish.mp3"
communicate = edge_tts.Communicate() communicate = edge_tts.Communicate(TEXT, VOICE)
await communicate.save(OUTPUT_FILE)
with open(OUTPUT_FILE, "wb") as f:
async for i in communicate.run(TEXT, voice=VOICE):
if i[2] is not None:
f.write(i[2])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.get_event_loop().run_until_complete(main()) asyncio.get_event_loop().run_until_complete(main())

View File

@@ -1,24 +0,0 @@
#!/usr/bin/env python3
"""
Example Python script that shows how to use edge-tts as a module
"""
import asyncio
import edge_tts
async def main():
    """
    Speak a fixed greeting and write the received audio to test.mp3.
    """
    greeting = "Hello World!"
    speaker = "en-GB-SoniaNeural"
    destination = "test.mp3"

    tts = edge_tts.Communicate()
    with open(destination, "wb") as audio_file:
        async for chunk in tts.run(greeting, voice=speaker):
            # Only the third element of each yielded item is written;
            # items where it is None are skipped.
            payload = chunk[2]
            if payload is None:
                continue
            audio_file.write(payload)


if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    loop.run_until_complete(main())

View File

@@ -1,3 +1,4 @@
find src examples -name '*.py' | xargs black find src examples -name '*.py' | xargs black
find src examples -name '*.py' | xargs isort find src examples -name '*.py' | xargs isort
find src examples -name '*.py' | xargs pylint find src examples -name '*.py' | xargs pylint
find src examples -name '*.py' | xargs mypy

29
mypy.ini Normal file
View File

@@ -0,0 +1,29 @@
[mypy]
disallow_any_unimported = True
disallow_any_expr = False
disallow_any_decorated = True
disallow_any_explicit = False
disallow_any_generics = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
check_untyped_defs = True
disallow_untyped_decorators = True
implicit_optional = False
strict_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_no_return = True
warn_return_any = True
warn_unreachable = True
strict_concatenate = True
strict_equality = True
strict = True
[mypy-edge_tts.list_voices]
disallow_any_decorated = False

View File

@@ -27,4 +27,11 @@ where=src
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
edge-tts = edge_tts.__main__:main edge-tts = edge_tts.__main__:main
edge-playback = edge_playback.__init__:main edge-playback = edge_playback.__main__:main
[options.extras_require]
dev =
black
isort
mypy
pylint

View File

@@ -1,56 +0,0 @@
#!/usr/bin/env python3
"""
Playback TTS with subtitles using edge-tts and mpv.
"""
import os
import subprocess
import sys
import tempfile
from shutil import which
def main():
    """
    Generate speech with edge-tts into temporary files, then play it in mpv.

    Command-line arguments are forwarded unchanged to edge-tts. The
    temporary media and subtitle files are removed when playback finishes
    or when an error interrupts the run.
    """
    # Guard clause: both external programs must be on PATH.
    if not (which("mpv") and which("edge-tts")):
        print("This script requires mpv and edge-tts.")
        return

    media = tempfile.NamedTemporaryFile(delete=False)
    subtitle = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Only the file names are needed; the child processes open the files.
        media.close()
        subtitle.close()

        print()
        print(f"Media file: {media.name}")
        print(f"Subtitle file: {subtitle.name}\n")

        synth_cmd = [
            "edge-tts",
            "--boundary-type=1",
            f"--write-media={media.name}",
            f"--write-subtitles={subtitle.name}",
            *sys.argv[1:],
        ]
        with subprocess.Popen(synth_cmd) as synth:
            synth.communicate()

        player_cmd = [
            "mpv",
            "--keep-open=yes",
            f"--sub-file={subtitle.name}",
            media.name,
        ]
        with subprocess.Popen(player_cmd) as player:
            player.communicate()
    finally:
        os.unlink(media.name)
        os.unlink(subtitle.name)


if __name__ == "__main__":
    main()

View File

@@ -1,10 +1,63 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
This is the main file for the edge_playback package. Playback TTS with subtitles using edge-tts and mpv.
""" """
from edge_playback.__init__ import main import os
import subprocess
import sys
import tempfile
from shutil import which
def main() -> None:
    """
    Generate speech with edge-tts into temporary files, then play it in mpv.

    Exits with status 1 when mpv or edge-tts cannot be found on PATH.
    Command-line arguments are forwarded unchanged to edge-tts. Temporary
    files that were created are removed before returning.
    """
    # Report every missing dependency before giving up, not just the first.
    complaints = (
        ("mpv", "mpv is not installed."),
        ("edge-tts", "edge-tts is not installed."),
    )
    depcheck_failed = False
    for executable, complaint in complaints:
        if not which(executable):
            print(complaint, file=sys.stderr)
            depcheck_failed = True

    if depcheck_failed:
        print("Please install the missing dependencies.", file=sys.stderr)
        sys.exit(1)

    # Start as None so the cleanup below only touches files that exist.
    media = None
    subtitle = None
    try:
        media = tempfile.NamedTemporaryFile(delete=False)
        media.close()
        subtitle = tempfile.NamedTemporaryFile(delete=False)
        subtitle.close()

        print(f"Media file: {media.name}")
        print(f"Subtitle file: {subtitle.name}\n")

        synth_cmd = [
            "edge-tts",
            f"--write-media={media.name}",
            f"--write-subtitles={subtitle.name}",
            *sys.argv[1:],
        ]
        with subprocess.Popen(synth_cmd) as synth:
            synth.communicate()

        player_cmd = [
            "mpv",
            f"--sub-file={subtitle.name}",
            media.name,
        ]
        with subprocess.Popen(player_cmd) as player:
            player.communicate()
    finally:
        for leftover in (media, subtitle):
            if leftover is not None:
                os.unlink(leftover.name)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

View File

@@ -3,5 +3,7 @@ __init__ for edge_tts
""" """
from .communicate import Communicate from .communicate import Communicate
from .list_voices import list_voices, VoicesManager from .list_voices import VoicesManager, list_voices
from .submaker import SubMaker from .submaker import SubMaker
__all__ = ["Communicate", "VoicesManager", "list_voices", "SubMaker"]

View File

@@ -4,16 +4,21 @@ Communicate package.
import json import json
import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from xml.sax.saxutils import escape from xml.sax.saxutils import escape
import aiohttp import aiohttp
from edge_tts.exceptions import (NoAudioReceived, UnexpectedResponse,
UnknownResponse)
from .constants import WSS_URL from .constants import WSS_URL
def get_headers_and_data(data): def get_headers_and_data(data: str | bytes) -> tuple[Dict[str, str], bytes]:
""" """
Returns the headers and data from the given data. Returns the headers and data from the given data.
@@ -25,6 +30,8 @@ def get_headers_and_data(data):
""" """
if isinstance(data, str): if isinstance(data, str):
data = data.encode("utf-8") data = data.encode("utf-8")
if not isinstance(data, bytes):
raise TypeError("data must be str or bytes")
headers = {} headers = {}
for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"): for line in data.split(b"\r\n\r\n")[0].split(b"\r\n"):
@@ -37,7 +44,7 @@ def get_headers_and_data(data):
return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:]) return headers, b"\r\n\r\n".join(data.split(b"\r\n\r\n")[1:])
def remove_incompatible_characters(string): def remove_incompatible_characters(string: str | bytes) -> str:
""" """
The service does not support a couple character ranges. The service does not support a couple character ranges.
Most important being the vertical tab character which is Most important being the vertical tab character which is
@@ -52,31 +59,30 @@ def remove_incompatible_characters(string):
""" """
if isinstance(string, bytes): if isinstance(string, bytes):
string = string.decode("utf-8") string = string.decode("utf-8")
if not isinstance(string, str):
raise TypeError("string must be str or bytes")
string = list(string) chars: List[str] = list(string)
for idx, char in enumerate(string): for idx, char in enumerate(chars):
code = ord(char) code: int = ord(char)
if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31): if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31):
string[idx] = " " chars[idx] = " "
return "".join(string) return "".join(chars)
def connect_id(): def connect_id() -> str:
""" """
Returns a UUID without dashes. Returns a UUID without dashes.
Args:
None
Returns: Returns:
str: A UUID without dashes. str: A UUID without dashes.
""" """
return str(uuid.uuid4()).replace("-", "") return str(uuid.uuid4()).replace("-", "")
def iter_bytes(my_bytes): def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
""" """
Iterates over bytes object Iterates over bytes object
@@ -90,20 +96,22 @@ def iter_bytes(my_bytes):
yield my_bytes[i : i + 1] yield my_bytes[i : i + 1]
def split_text_by_byte_length(text, byte_length): def split_text_by_byte_length(text: str | bytes, byte_length: int) -> List[bytes]:
""" """
Splits a string into a list of strings of a given byte length Splits a string into a list of strings of a given byte length
while attempting to keep words together. while attempting to keep words together.
Args: Args:
text (byte): The string to be split. text (str or bytes): The string to be split.
byte_length (int): The byte length of each string in the list. byte_length (int): The maximum byte length of each string in the list.
Returns: Returns:
list: A list of strings of the given byte length. list: A list of bytes of the given byte length.
""" """
if isinstance(text, str): if isinstance(text, str):
text = text.encode("utf-8") text = text.encode("utf-8")
if not isinstance(text, bytes):
raise TypeError("text must be str or bytes")
words = [] words = []
while len(text) > byte_length: while len(text) > byte_length:
@@ -125,17 +133,10 @@ def split_text_by_byte_length(text, byte_length):
return words return words
def mkssml(text, voice, pitch, rate, volume): def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) -> str:
""" """
Creates a SSML string from the given parameters. Creates a SSML string from the given parameters.
Args:
text (str): The text to be spoken.
voice (str): The voice to be used.
pitch (str): The pitch to be used.
rate (str): The rate to be used.
volume (str): The volume to be used.
Returns: Returns:
str: The SSML string. str: The SSML string.
""" """
@@ -150,13 +151,10 @@ def mkssml(text, voice, pitch, rate, volume):
return ssml return ssml
def date_to_string(): def date_to_string() -> str:
""" """
Return Javascript-style date string. Return Javascript-style date string.
Args:
None
Returns: Returns:
str: Javascript-style date string. str: Javascript-style date string.
""" """
@@ -171,15 +169,10 @@ def date_to_string():
) )
def ssml_headers_plus_data(request_id, timestamp, ssml): def ssml_headers_plus_data(request_id: str, timestamp: str, ssml: str) -> str:
""" """
Returns the headers and data to be used in the request. Returns the headers and data to be used in the request.
Args:
request_id (str): The request ID.
timestamp (str): The timestamp.
ssml (str): The SSML string.
Returns: Returns:
str: The headers and data to be used in the request. str: The headers and data to be used in the request.
""" """
@@ -198,73 +191,85 @@ class Communicate:
Class for communicating with the service. Class for communicating with the service.
""" """
def __init__(self): def __init__(
"""
Initializes the Communicate class.
"""
self.date = date_to_string()
async def run(
self, self,
messages, text: str,
boundary_type=0, voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
codec="audio-24khz-48kbitrate-mono-mp3", *,
voice="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", pitch: str = "+0Hz",
pitch="+0Hz", rate: str = "+0%",
rate="+0%", volume: str = "+0%",
volume="+0%", proxy: Optional[str] = None,
proxy=None,
): ):
""" """
Runs the Communicate class. Initializes the Communicate class.
Args: Raises:
messages (str or list): A list of SSML strings or a single text. ValueError: If the voice is not valid.
boundery_type (int): The type of boundary to use. 0 for none, 1 for word_boundary, 2 for sentence_boundary.
codec (str): The codec to use.
voice (str): The voice to use.
pitch (str): The pitch to use.
rate (str): The rate to use.
volume (str): The volume to use.
Yields:
tuple: The subtitle offset, subtitle, and audio data.
""" """
self.text: str = text
word_boundary = False self.codec: str = "audio-24khz-48kbitrate-mono-mp3"
self.voice: str = voice
if boundary_type > 0: # Possible values for voice are:
word_boundary = True # - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
if boundary_type > 1: # - cy-GB-NiaNeural
raise ValueError( # Always send the first variant as that is what Microsoft Edge does.
"Invalid boundary type. SentenceBoundary is no longer supported." match = re.match(r"^([a-z]{2})-([A-Z]{2})-(.+Neural)$", voice)
if match is not None:
self.voice = (
"Microsoft Server Speech Text to Speech Voice"
+ f" ({match.group(1)}-{match.group(2)}, {match.group(3)})"
) )
word_boundary = str(word_boundary).lower() if (
re.match(
r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$",
self.voice,
)
is None
):
raise ValueError(f"Invalid voice '{voice}'.")
websocket_max_size = 2 ** 16 if re.match(r"^[+-]\d+Hz$", pitch) is None:
raise ValueError(f"Invalid pitch '{pitch}'.")
self.pitch: str = pitch
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None:
raise ValueError(f"Invalid rate '{rate}'.")
self.rate: str = rate
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None:
raise ValueError(f"Invalid volume '{volume}'.")
self.volume: str = volume
self.proxy: Optional[str] = proxy
async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
"""Streams audio and metadata from the service."""
websocket_max_size = 2**16
overhead_per_message = ( overhead_per_message = (
len( len(
ssml_headers_plus_data( ssml_headers_plus_data(
connect_id(), self.date, mkssml("", voice, pitch, rate, volume) connect_id(),
date_to_string(),
mkssml("", self.voice, self.pitch, self.rate, self.volume),
) )
) )
+ 50 + 50 # margin of error
) # margin of error )
messages = split_text_by_byte_length( texts = split_text_by_byte_length(
escape(remove_incompatible_characters(messages)), escape(remove_incompatible_characters(self.text)),
websocket_max_size - overhead_per_message, websocket_max_size - overhead_per_message,
) )
# Variables for the loop
download = False
async with aiohttp.ClientSession(trust_env=True) as session: async with aiohttp.ClientSession(trust_env=True) as session:
async with session.ws_connect( async with session.ws_connect(
f"{WSS_URL}&ConnectionId={connect_id()}", f"{WSS_URL}&ConnectionId={connect_id()}",
compress=15, compress=15,
autoclose=True, autoclose=True,
autoping=True, autoping=True,
proxy=proxy, proxy=self.proxy,
headers={ headers={
"Pragma": "no-cache", "Pragma": "no-cache",
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
@@ -275,9 +280,19 @@ class Communicate:
" (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41", " (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36 Edg/91.0.864.41",
}, },
) as websocket: ) as websocket:
for message in messages: for text in texts:
# download_audio indicates whether we should be expecting audio data,
# this is so that we avoid getting binary data from the websocket
# and falsely thinking it's audio data.
download_audio = False
# audio_was_received indicates whether we have received audio data
# from the websocket. This is so we can raise an exception if we
# don't receive any audio data.
audio_was_received = False
# Each message needs to have the proper date # Each message needs to have the proper date
self.date = date_to_string() date = date_to_string()
# Prepare the request to be sent to the service. # Prepare the request to be sent to the service.
# #
@@ -290,26 +305,26 @@ class Communicate:
# #
# Also pay close attention to double { } in request (escape for f-string). # Also pay close attention to double { } in request (escape for f-string).
request = ( request = (
f"X-Timestamp:{self.date}\r\n" f"X-Timestamp:{date}\r\n"
"Content-Type:application/json; charset=utf-8\r\n" "Content-Type:application/json; charset=utf-8\r\n"
"Path:speech.config\r\n\r\n" "Path:speech.config\r\n\r\n"
'{"context":{"synthesis":{"audio":{"metadataoptions":{' '{"context":{"synthesis":{"audio":{"metadataoptions":{'
f'"sentenceBoundaryEnabled":false,' '"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},'
f'"wordBoundaryEnabled":{word_boundary}}},"outputFormat":"{codec}"' f'"outputFormat":"{self.codec}"'
"}}}}\r\n" "}}}}\r\n"
) )
# Send the request to the service.
await websocket.send_str(request) await websocket.send_str(request)
# Send the message itself.
await websocket.send_str( await websocket.send_str(
ssml_headers_plus_data( ssml_headers_plus_data(
connect_id(), connect_id(),
self.date, date,
mkssml(message, voice, pitch, rate, volume), mkssml(
text, self.voice, self.pitch, self.rate, self.volume
),
) )
) )
# Begin listening for the response.
async for received in websocket: async for received in websocket:
if received.type == aiohttp.WSMsgType.TEXT: if received.type == aiohttp.WSMsgType.TEXT:
parameters, data = get_headers_and_data(received.data) parameters, data = get_headers_and_data(received.data)
@@ -317,76 +332,101 @@ class Communicate:
"Path" in parameters "Path" in parameters
and parameters["Path"] == "turn.start" and parameters["Path"] == "turn.start"
): ):
download = True download_audio = True
elif ( elif (
"Path" in parameters "Path" in parameters
and parameters["Path"] == "turn.end" and parameters["Path"] == "turn.end"
): ):
download = False download_audio = False
break break
elif ( elif (
"Path" in parameters "Path" in parameters
and parameters["Path"] == "audio.metadata" and parameters["Path"] == "audio.metadata"
): ):
metadata = json.loads(data) metadata = json.loads(data)
metadata_type = metadata["Metadata"][0]["Type"] for i in range(len(metadata["Metadata"])):
metadata_offset = metadata["Metadata"][0]["Data"][ metadata_type = metadata["Metadata"][i]["Type"]
"Offset" metadata_offset = metadata["Metadata"][i]["Data"][
] "Offset"
if metadata_type == "WordBoundary":
metadata_duration = metadata["Metadata"][0]["Data"][
"Duration"
] ]
metadata_text = metadata["Metadata"][0]["Data"][ if metadata_type == "WordBoundary":
"text" metadata_duration = metadata["Metadata"][i][
]["Text"] "Data"
yield ( ]["Duration"]
[ metadata_text = metadata["Metadata"][i]["Data"][
metadata_offset, "text"
metadata_duration, ]["Text"]
], yield {
metadata_text, "type": metadata_type,
None, "offset": metadata_offset,
) "duration": metadata_duration,
elif metadata_type == "SentenceBoundary": "text": metadata_text,
raise NotImplementedError( }
"SentenceBoundary is not supported due to being broken." elif metadata_type == "SentenceBoundary":
) raise UnknownResponse(
elif metadata_type == "SessionEnd": "SentenceBoundary is not supported due to being broken."
continue )
else: elif metadata_type == "SessionEnd":
raise NotImplementedError( continue
f"Unknown metadata type: {metadata_type}" else:
) raise UnknownResponse(
f"Unknown metadata type: {metadata_type}"
)
elif ( elif (
"Path" in parameters "Path" in parameters
and parameters["Path"] == "response" and parameters["Path"] == "response"
): ):
# TODO: implement this:
"""
X-RequestId:xxxxxxxxxxxxxxxxxxxxxxxxx
Content-Type:application/json; charset=utf-8
Path:response
{"context":{"serviceTag":"yyyyyyyyyyyyyyyyyyy"},"audio":{"type":"inline","streamId":"zzzzzzzzzzzzzzzzz"}}
"""
pass pass
else: else:
raise ValueError( raise UnknownResponse(
"The response from the service is not recognized.\n" "The response from the service is not recognized.\n"
+ received.data + received.data
) )
elif received.type == aiohttp.WSMsgType.BINARY: elif received.type == aiohttp.WSMsgType.BINARY:
if download: if download_audio:
yield ( yield {
None, "type": "audio",
None, "data": b"Path:audio\r\n".join(
b"Path:audio\r\n".join(
received.data.split(b"Path:audio\r\n")[1:] received.data.split(b"Path:audio\r\n")[1:]
), ),
) }
audio_was_received = True
else: else:
raise ValueError( raise UnexpectedResponse(
"The service sent a binary message, but we are not expecting one." "We received a binary message, but we are not expecting one."
) )
await websocket.close()
if not audio_was_received:
raise NoAudioReceived(
"No audio was received. Please verify that your parameters are correct."
)
async def save(
self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None
) -> None:
"""
Save the audio and metadata to the specified files.
"""
written_audio = False
try:
audio = open(audio_fname, "wb")
metadata = None
if metadata_fname is not None:
metadata = open(metadata_fname, "w", encoding="utf-8")
async for message in self.stream():
if message["type"] == "audio":
audio.write(message["data"])
written_audio = True
elif metadata is not None and message["type"] == "WordBoundary":
json.dump(message, metadata)
metadata.write("\n")
finally:
audio.close()
if metadata is not None:
metadata.close()
if not written_audio:
raise NoAudioReceived(
"No audio was received from the service, so the file is empty."
)

View File

@@ -0,0 +1,16 @@
"""Exceptions for the Edge TTS project."""
class UnknownResponse(Exception):
    """Signal that the server sent a response this library does not recognize."""
class UnexpectedResponse(Exception):
    """Signal that the server sent a response at a point where none was expected.

    This has not been observed so far; it exists in case the server
    changes its response format in the future.
    """
class NoAudioReceived(Exception):
    """Signal that the server did not deliver any audio data."""

View File

@@ -1,21 +1,21 @@
""" """
list_voices package. list_voices package for edge_tts.
""" """
import json import json
from typing import Any, Dict, List, Optional
import aiohttp import aiohttp
from .constants import VOICE_LIST from .constants import VOICE_LIST
async def list_voices(proxy=None): async def list_voices(*, proxy: Optional[str] = None) -> Any:
""" """
List all available voices and their attributes. List all available voices and their attributes.
This pulls data from the URL used by Microsoft Edge to return a list of This pulls data from the URL used by Microsoft Edge to return a list of
all available voices. However many more experimental voices are available all available voices.
than are listed here. (See https://aka.ms/csspeech/voicenames)
Returns: Returns:
dict: A dictionary of voice attributes. dict: A dictionary of voice attributes.
@@ -47,20 +47,32 @@ class VoicesManager:
A class to find the correct voice based on their attributes. A class to find the correct voice based on their attributes.
""" """
def __init__(self) -> None:
self.voices: List[Dict[str, Any]] = []
self.called_create: bool = False
@classmethod @classmethod
async def create(cls): async def create(cls: Any) -> "VoicesManager":
"""
Creates a VoicesManager object and populates it with all available voices.
"""
self = VoicesManager() self = VoicesManager()
self.voices = await list_voices() self.voices = await list_voices()
self.voices = [ self.voices = [
{**voice, **{"Language": voice["Locale"].split("-")[0]}} {**voice, **{"Language": voice["Locale"].split("-")[0]}}
for voice in self.voices for voice in self.voices
] ]
self.called_create = True
return self return self
def find(self, **kwargs): def find(self, **kwargs: Any) -> List[Dict[str, Any]]:
""" """
Finds all matching voices based on the provided attributes. Finds all matching voices based on the provided attributes.
""" """
if not self.called_create:
raise RuntimeError(
"VoicesManager.find() called before VoicesManager.create()"
)
matching_voices = [ matching_voices = [
voice for voice in self.voices if kwargs.items() <= voice.items() voice for voice in self.voices if kwargs.items() <= voice.items()

View File

@@ -6,10 +6,11 @@ information provided by the service easier.
""" """
import math import math
from typing import List, Tuple
from xml.sax.saxutils import escape, unescape from xml.sax.saxutils import escape, unescape
def formatter(offset1, offset2, subdata): def formatter(offset1: float, offset2: float, subdata: str) -> str:
""" """
formatter returns the timecode and the text of the subtitle. formatter returns the timecode and the text of the subtitle.
""" """
@@ -19,7 +20,7 @@ def formatter(offset1, offset2, subdata):
) )
def mktimestamp(time_unit): def mktimestamp(time_unit: float) -> str:
""" """
mktimestamp returns the timecode of the subtitle. mktimestamp returns the timecode of the subtitle.
@@ -28,9 +29,9 @@ def mktimestamp(time_unit):
Returns: Returns:
str: The timecode of the subtitle. str: The timecode of the subtitle.
""" """
hour = math.floor(time_unit / 10 ** 7 / 3600) hour = math.floor(time_unit / 10**7 / 3600)
minute = math.floor((time_unit / 10 ** 7 / 60) % 60) minute = math.floor((time_unit / 10**7 / 60) % 60)
seconds = (time_unit / 10 ** 7) % 60 seconds = (time_unit / 10**7) % 60
return f"{hour:02d}:{minute:02d}:{seconds:06.3f}" return f"{hour:02d}:{minute:02d}:{seconds:06.3f}"
@@ -39,7 +40,7 @@ class SubMaker:
SubMaker class SubMaker class
""" """
def __init__(self, overlapping=1): def __init__(self, overlapping: int = 1) -> None:
""" """
SubMaker constructor. SubMaker constructor.
@@ -47,10 +48,11 @@ class SubMaker:
overlapping (int): The amount of time in seconds that the overlapping (int): The amount of time in seconds that the
subtitles should overlap. subtitles should overlap.
""" """
self.subs_and_offset = [] self.offset: List[Tuple[float, float]] = []
self.overlapping = overlapping * (10 ** 7) self.subs: List[str] = []
self.overlapping: int = overlapping * (10**7)
def create_sub(self, timestamp, text): def create_sub(self, timestamp: Tuple[float, float], text: str) -> None:
""" """
create_sub creates a subtitle with the given timestamp and text create_sub creates a subtitle with the given timestamp and text
and adds it to the list of subtitles and adds it to the list of subtitles
@@ -62,40 +64,39 @@ class SubMaker:
Returns: Returns:
None None
""" """
timestamp[1] += timestamp[0] self.offset.append((timestamp[0], timestamp[0] + timestamp[1]))
self.subs_and_offset.append(timestamp) self.subs.append(text)
self.subs_and_offset.append(text)
def generate_subs(self): def generate_subs(self) -> str:
""" """
generate_subs generates the complete subtitle file. generate_subs generates the complete subtitle file.
Returns: Returns:
str: The complete subtitle file. str: The complete subtitle file.
""" """
if len(self.subs_and_offset) >= 2: if len(self.subs) == len(self.offset):
data = "WEBVTT\r\n\r\n" data = "WEBVTT\r\n\r\n"
for offset, subs in zip( for offset, subs in zip(self.offset, self.subs):
self.subs_and_offset[::2], self.subs_and_offset[1::2]
):
subs = unescape(subs) subs = unescape(subs)
subs = [subs[i : i + 79] for i in range(0, len(subs), 79)] split_subs: List[str] = [
subs[i : i + 79] for i in range(0, len(subs), 79)
]
for i in range(len(subs) - 1): for i in range(len(split_subs) - 1):
sub = subs[i] sub = split_subs[i]
split_at_word = True split_at_word = True
if sub[-1] == " ": if sub[-1] == " ":
subs[i] = sub[:-1] split_subs[i] = sub[:-1]
split_at_word = False split_at_word = False
if sub[0] == " ": if sub[0] == " ":
subs[i] = sub[1:] split_subs[i] = sub[1:]
split_at_word = False split_at_word = False
if split_at_word: if split_at_word:
subs[i] += "-" split_subs[i] += "-"
subs = "\r\n".join(subs) subs = "\r\n".join(split_subs)
data += formatter(offset[0], offset[1] + self.overlapping, subs) data += formatter(offset[0], offset[1] + self.overlapping, subs)
return data return data

View File

@@ -6,14 +6,14 @@ Main package.
import argparse import argparse
import asyncio import asyncio
import sys import sys
from io import BufferedWriter
from typing import Any
from edge_tts import Communicate, SubMaker, list_voices from edge_tts import Communicate, SubMaker, list_voices
async def _list_voices(proxy): async def _print_voices(*, proxy: str) -> None:
""" """Print all available voices."""
List available voices.
"""
for idx, voice in enumerate(await list_voices(proxy=proxy)): for idx, voice in enumerate(await list_voices(proxy=proxy)):
if idx != 0: if idx != 0:
print() print()
@@ -25,38 +25,41 @@ async def _list_voices(proxy):
print(f"{key}: {voice[key]}") print(f"{key}: {voice[key]}")
async def _tts(args): async def _run_tts(args: Any) -> None:
tts = Communicate() """Run TTS after parsing arguments from command line."""
subs = SubMaker(args.overlapping) tts = Communicate(
if args.write_media:
media_file = open(args.write_media, "wb") # pylint: disable=consider-using-with
async for i in tts.run(
args.text, args.text,
args.boundary_type,
args.codec,
args.voice, args.voice,
args.pitch,
args.rate,
args.volume,
proxy=args.proxy, proxy=args.proxy,
): rate=args.rate,
if i[2] is not None: volume=args.volume,
if not args.write_media: )
sys.stdout.buffer.write(i[2]) try:
else: media_file = None
media_file.write(i[2]) if args.write_media:
elif i[0] is not None and i[1] is not None: media_file = open(args.write_media, "wb")
subs.create_sub(i[0], i[1])
if args.write_media: subs = SubMaker(args.overlapping)
media_file.close() async for data in tts.stream():
if not args.write_subtitles: if data["type"] == "audio":
sys.stderr.write(subs.generate_subs()) if isinstance(media_file, BufferedWriter):
else: media_file.write(data["data"])
with open(args.write_subtitles, "w", encoding="utf-8") as file: else:
file.write(subs.generate_subs()) sys.stdout.buffer.write(data["data"])
elif data["type"] == "WordBoundary":
subs.create_sub((data["offset"], data["duration"]), data["text"])
if not args.write_subtitles:
sys.stderr.write(subs.generate_subs())
else:
with open(args.write_subtitles, "w", encoding="utf-8") as file:
file.write(subs.generate_subs())
finally:
if media_file is not None:
media_file.close()
async def _main(): async def _async_main() -> None:
parser = argparse.ArgumentParser(description="Microsoft Edge TTS") parser = argparse.ArgumentParser(description="Microsoft Edge TTS")
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-t", "--text", help="what TTS will say") group.add_argument("-t", "--text", help="what TTS will say")
@@ -64,23 +67,13 @@ async def _main():
parser.add_argument( parser.add_argument(
"-v", "-v",
"--voice", "--voice",
help="voice for TTS. " help="voice for TTS. " "Default: en-US-AriaNeural",
"Default: Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)", default="en-US-AriaNeural",
default="Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
)
parser.add_argument(
"-c",
"--codec",
help="codec format. Default: audio-24khz-48kbitrate-mono-mp3. "
"Another choice is webm-24khz-16bit-mono-opus. "
"For more info check https://bit.ly/2T33h6S",
default="audio-24khz-48kbitrate-mono-mp3",
) )
group.add_argument( group.add_argument(
"-l", "-l",
"--list-voices", "--list-voices",
help="lists available voices. " help="lists available voices",
"Edge's list is incomplete so check https://bit.ly/2SFq1d3",
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
@@ -109,32 +102,19 @@ async def _main():
type=float, type=float,
) )
parser.add_argument( parser.add_argument(
"-b", "--write-media", help="send media output to file instead of stdout"
"--boundary-type",
help="set boundary type for subtitles. Default 0 for none. Set 1 for word_boundary.",
default=0,
type=int,
)
parser.add_argument(
"--write-media", help="instead of stdout, send media output to provided file"
) )
parser.add_argument( parser.add_argument(
"--write-subtitles", "--write-subtitles",
help="instead of stderr, send subtitle output to provided file (implies boundary-type is 1)", help="send subtitle output to provided file instead of stderr",
)
parser.add_argument(
"--proxy",
help="proxy",
) )
parser.add_argument("--proxy", help="use a proxy for TTS and voice list.")
args = parser.parse_args() args = parser.parse_args()
if args.list_voices: if args.list_voices:
await _list_voices(args.proxy) await _print_voices(proxy=args.proxy)
sys.exit(0) sys.exit(0)
if args.write_subtitles and args.boundary_type == 0:
args.boundary_type = 1
if args.text is not None or args.file is not None: if args.text is not None or args.file is not None:
if args.file is not None: if args.file is not None:
# we need to use sys.stdin.read() because some devices # we need to use sys.stdin.read() because some devices
@@ -147,14 +127,12 @@ async def _main():
with open(args.file, "r", encoding="utf-8") as file: with open(args.file, "r", encoding="utf-8") as file:
args.text = file.read() args.text = file.read()
await _tts(args) await _run_tts(args)
def main(): def main() -> None:
""" """Run the main function using asyncio."""
Main function. asyncio.get_event_loop().run_until_complete(_async_main())
"""
asyncio.get_event_loop().run_until_complete(_main())
if __name__ == "__main__": if __name__ == "__main__":