160 lines
6.0 KiB
Python
160 lines
6.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
import configparser
import logging
import os
import shutil
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
from gradio_client import Client, handle_file
from pydub import AudioSegment
from pydub.silence import split_on_silence
|
|
|
|
# Transcription backend switch used by send_request:
# True  -> remote HTTP API (token/URL read from config.ini),
# False -> local Gradio server.
use_remote_api = False

# Worker count for both the per-file pool (main) and the per-chunk pool
# (process_audio_file). The pools are nested, so up to process_workers**2
# transcription requests can be in flight at once.
process_workers = 5 if use_remote_api else 1

# ConfigParser.read silently skips files that do not exist, so a missing
# config.ini only shows up when a key is looked up below.
config = configparser.ConfigParser()
config.read('config.ini')
_default_section = config['DEFAULT']

if use_remote_api:
    # The remote API cannot work without these; fail fast at import time.
    token = _default_section['voice_token']
    url = _default_section['voice2txt_url']
else:
    # Local Gradio mode never uses the token or URL, so don't require them.
    token = _default_section.get('voice_token', '')
    url = _default_section.get('voice2txt_url', '')

headers = {
    "Authorization": f'Bearer {token}'
}
|
|
|
|
|
|
# 配置日志
|
|
def configure_logging():
|
|
logger = logging.getLogger()
|
|
logger.setLevel(logging.INFO)
|
|
|
|
# 文件日志处理器
|
|
file_handler = logging.FileHandler('audio_transcription.log')
|
|
file_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
file_handler.setFormatter(file_formatter)
|
|
logger.addHandler(file_handler)
|
|
|
|
# 控制台日志处理器
|
|
console_handler = logging.StreamHandler()
|
|
console_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
console_handler.setFormatter(console_formatter)
|
|
logger.addHandler(console_handler)
|
|
|
|
|
|
configure_logging()
|
|
|
|
|
|
def process_audio_file(audio_file_path):
|
|
file_name_with_extension = os.path.basename(audio_file_path)
|
|
file_name_without_extension = os.path.splitext(file_name_with_extension)[0]
|
|
logging.info(f"Starting processing {file_name_with_extension}")
|
|
# 获取 WAV 文件所在的目录
|
|
wav_dir = os.path.dirname(audio_file_path)
|
|
# 获取 MP4 文件的文件名(不包含扩展名)
|
|
wav_filename = os.path.splitext(os.path.basename(audio_file_path))[0]
|
|
# 生成对应的 WAV 文件路径
|
|
md_file = os.path.join(wav_dir, f"{wav_filename}.md")
|
|
# 创建一个锁对象
|
|
file_write_lock = threading.Lock()
|
|
|
|
try:
|
|
audio_file = AudioSegment.from_wav(audio_file_path)
|
|
audio_chunks = split_on_silence(audio_file, min_silence_len=500, silence_thresh=-40)
|
|
|
|
# 处理超过30秒的音频片段
|
|
new_audio_chunks = []
|
|
for chunk in audio_chunks:
|
|
if len(chunk) > 60000: # 60秒转换为毫秒
|
|
new_audio_chunks.extend(split_on_silence(chunk, min_silence_len=250, silence_thresh=-40))
|
|
else:
|
|
new_audio_chunks.append(chunk)
|
|
|
|
sentences = [] # 用于存储所有句子的列表
|
|
|
|
with ThreadPoolExecutor(max_workers=process_workers) as executor:
|
|
futures = {executor.submit(send_request, chunk, i, file_name_without_extension): i for i, chunk in
|
|
enumerate(new_audio_chunks)}
|
|
|
|
for future in as_completed(futures):
|
|
index = futures[future]
|
|
try:
|
|
result = future.result()
|
|
sentences.append(result) # 存储索引和文本
|
|
except Exception as exc:
|
|
logging.error(f'Request {index} for {file_name_with_extension} generated an exception: {str(exc)}')
|
|
|
|
# 按照音频的顺序写入Markdown文件
|
|
markdown_content = ""
|
|
for sentence in sorted(sentences, key=lambda x: x[0]): # 根据索引排序
|
|
markdown_content += f"{sentence[1]}\n\n"
|
|
|
|
# with file_write_lock: # 确保文件写入操作的线程安全
|
|
# md_file_path = os.path.join('media', file_name_without_extension + '.md')
|
|
with open(md_file, "w", encoding="utf-8") as f:
|
|
f.write(markdown_content)
|
|
|
|
logging.info(f"Finished processing {file_name_with_extension}")
|
|
except Exception as e:
|
|
logging.error(f"Failed to process {file_name_with_extension}: {str(e)}")
|
|
|
|
|
|
def send_request(chunk, index, file_name_without_extension):
|
|
audio_part_path = os.path.join('media', f"{file_name_without_extension}_chunk_{index}.wav")
|
|
chunk.export(audio_part_path, format="wav")
|
|
# logging.info(f'Exported chunk file {audio_part_path} for {file_name_without_extension}')
|
|
try:
|
|
if use_remote_api:
|
|
multipart_form_data = {
|
|
'file': (os.path.basename(audio_part_path), open(audio_part_path, 'rb')),
|
|
'model': (None, 'FunAudioLLM/SenseVoiceSmall')
|
|
}
|
|
|
|
response = requests.post(url, files=multipart_form_data, headers=headers)
|
|
result = response.json()
|
|
text = result["text"]
|
|
print(text)
|
|
|
|
else:
|
|
client = Client("http://192.168.31.3:7860/")
|
|
text = client.predict(
|
|
input_wav=handle_file(audio_part_path),
|
|
language="zh",
|
|
api_name="/model_inference"
|
|
)
|
|
print(text)
|
|
|
|
return index, text # 返回索引和文本
|
|
except Exception as e:
|
|
logging.error(f'Error processing {file_name_without_extension}, chunk {index}: {str(e)}')
|
|
# 将出错的音频片段复制到error文件夹
|
|
error_dir = os.path.join(os.getcwd(), 'media', 'error')
|
|
if not os.path.exists(error_dir):
|
|
os.makedirs(error_dir)
|
|
error_path = os.path.join(error_dir, f"{file_name_without_extension}_chunk_{index}.wav")
|
|
shutil.copy(audio_part_path, error_path)
|
|
logging.error(f'Copied request failed chunk file {file_name_without_extension} to {error_path}')
|
|
return index, "" # 返回空文本
|
|
finally:
|
|
if os.path.exists(audio_part_path):
|
|
os.remove(audio_part_path)
|
|
|
|
|
|
def main():
|
|
# all_files = os.listdir('media')
|
|
# audio_files = [file for file in all_files if file.endswith('.wav')]
|
|
audio_files = []
|
|
for root, dirs, files in os.walk('media'):
|
|
for file in files:
|
|
if file.endswith('.wav'):
|
|
audio_files.append(os.path.join(root, file))
|
|
print(audio_files)
|
|
|
|
with ThreadPoolExecutor(max_workers=process_workers) as executor:
|
|
for audio_file in audio_files:
|
|
audio_file_path = os.path.join(audio_file)
|
|
executor.submit(process_audio_file, audio_file_path)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|