add more typing
This commit is contained in:
@@ -7,7 +7,7 @@ import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Generator, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
import aiohttp
|
||||
@@ -96,7 +96,7 @@ def iter_bytes(my_bytes: bytes) -> Generator[bytes, None, None]:
|
||||
yield my_bytes[i : i + 1]
|
||||
|
||||
|
||||
def split_text_by_byte_length(text: bytes, byte_length: int) -> List[bytes]:
|
||||
def split_text_by_byte_length(text: str | bytes, byte_length: int) -> List[bytes]:
|
||||
"""
|
||||
Splits a string into a list of strings of a given byte length
|
||||
while attempting to keep words together.
|
||||
@@ -151,7 +151,7 @@ def mkssml(text: str | bytes, voice: str, pitch: str, rate: str, volume: str) ->
|
||||
return ssml
|
||||
|
||||
|
||||
def date_to_string():
|
||||
def date_to_string() -> str:
|
||||
"""
|
||||
Return Javascript-style date string.
|
||||
|
||||
@@ -193,7 +193,7 @@ class Communicate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str | List[str],
|
||||
text: str,
|
||||
voice: str = "Microsoft Server Speech Text to Speech Voice (en-US, AriaNeural)",
|
||||
*,
|
||||
pitch: str = "+0Hz",
|
||||
@@ -207,9 +207,9 @@ class Communicate:
|
||||
Raises:
|
||||
ValueError: If the voice is not valid.
|
||||
"""
|
||||
self.text = text
|
||||
self.codec = "audio-24khz-48kbitrate-mono-mp3"
|
||||
self.voice = voice
|
||||
self.text: str = text
|
||||
self.codec: str = "audio-24khz-48kbitrate-mono-mp3"
|
||||
self.voice: str = voice
|
||||
# Possible values for voice are:
|
||||
# - Microsoft Server Speech Text to Speech Voice (cy-GB, NiaNeural)
|
||||
# - cy-GB-NiaNeural
|
||||
@@ -232,19 +232,19 @@ class Communicate:
|
||||
|
||||
if re.match(r"^[+-]\d+Hz$", pitch) is None:
|
||||
raise ValueError(f"Invalid pitch '{pitch}'.")
|
||||
self.pitch = pitch
|
||||
self.pitch: str = pitch
|
||||
|
||||
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", rate) is None:
|
||||
raise ValueError(f"Invalid rate '{rate}'.")
|
||||
self.rate = rate
|
||||
self.rate: str = rate
|
||||
|
||||
if re.match(r"^[+-]0*([0-9]|([1-9][0-9])|100)%$", volume) is None:
|
||||
raise ValueError(f"Invalid volume '{volume}'.")
|
||||
self.volume = volume
|
||||
self.volume: str = volume
|
||||
|
||||
self.proxy = proxy
|
||||
self.proxy: Optional[str] = proxy
|
||||
|
||||
async def stream(self):
|
||||
async def stream(self) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Streams audio and metadata from the service."""
|
||||
|
||||
websocket_max_size = 2**16
|
||||
@@ -403,7 +403,7 @@ class Communicate:
|
||||
|
||||
async def save(
|
||||
self, audio_fname: str | bytes, metadata_fname: Optional[str | bytes] = None
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Save the audio and metadata to the specified files.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user