Files
songyi/sense_voice_process.py
2025-04-15 09:11:14 +08:00

148 lines
4.4 KiB
Python

import os
import shutil
from concurrent.futures.thread import ThreadPoolExecutor
import logging
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from mpmath import convert
def configure_logging(log_file='audio_transcription.log', level=logging.INFO):
    """Attach a file handler and a console handler to the root logger.

    Safe to call more than once: handlers installed by an earlier call are
    removed and closed first, so log records are never emitted twice.

    Args:
        log_file: Path of the log file (appended to). Defaults to
            ``audio_transcription.log`` in the current working directory.
        level: Level set on the root logger. Defaults to ``logging.INFO``.
    """
    own_names = ("audio_transcription_file", "audio_transcription_console")
    logger = logging.getLogger()
    logger.setLevel(level)

    # Drop handlers from a previous call; otherwise each call would add
    # another pair and every record would be written multiple times.
    for handler in list(logger.handlers):
        if handler.get_name() in own_names:
            logger.removeHandler(handler)
            handler.close()

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # File handler. UTF-8 explicitly, because file names and transcripts may
    # contain non-ASCII (e.g. Chinese) text regardless of the platform locale.
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.set_name(own_names[0])
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Console handler (stderr).
    console_handler = logging.StreamHandler()
    console_handler.set_name(own_names[1])
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
configure_logging()

# Thread-pool size used by main(): 5 workers when use_remote_api is set,
# otherwise a single worker so the locally loaded model is driven by one
# thread at a time.
use_remote_api = False
process_workers = 5 if use_remote_api else 1

# NOTE(review): this global is not read by any function in this file (a
# function that assigns a name with this spelling treats it as a local).
# Kept only so existing imports of the name keep working.
conerted = False

model_dir = "iic/SenseVoiceSmall"

# The inference device defaults to the first CUDA GPU. Set the
# SENSE_VOICE_DEVICE environment variable (e.g. "cpu" or "cuda:1") to
# override it without editing this file.
model = AutoModel(
    model=model_dir,
    trust_remote_code=True,
    # Relative path, presumably resolved against the working directory -- confirm.
    remote_code="./model.py",
    vad_model="fsmn-vad",
    # Assumed to be milliseconds (30 s per VAD segment) -- TODO confirm in FunASR docs.
    vad_kwargs={"max_single_segment_time": 30000},
    device=os.environ.get("SENSE_VOICE_DEVICE", "cuda:0"),
)
import os
from pydub import AudioSegment
def mp3_to_wav(mp3_path):
    """Convert an MP3 file to WAV and return the path of the new file.

    The WAV file is written into the same directory as ``mp3_path``, using
    the same base name with a ``.wav`` extension.

    Args:
        mp3_path (str): Path to the source MP3 file.

    Returns:
        str: Path of the written WAV file.

    Raises:
        FileNotFoundError: If ``mp3_path`` is not an existing regular file.
    """
    if os.path.isfile(mp3_path):
        target_dir = os.path.dirname(mp3_path)
        stem, _ext = os.path.splitext(os.path.basename(mp3_path))
        wav_path = os.path.join(target_dir, f"{stem}.wav")
        # Decode the MP3 and re-encode it as WAV in one chained call.
        AudioSegment.from_mp3(mp3_path).export(wav_path, format="wav")
        return wav_path
    raise FileNotFoundError(f"文件未找到: {mp3_path}")
def short_audio_process(audio_file_path):
    """Transcribe a single audio file and return the post-processed text.

    Handling of MP3 input (the suffix check ignores case):
    - The file is converted to a WAV file beside the original via ``mp3_to_wav``.
    - That WAV file is deleted once transcription finishes, even if
      transcription raises.

    Any other input is passed to the model unchanged.

    Args:
        audio_file_path (str): Path to the audio file.

    Returns:
        str: Model output passed through ``rich_transcription_postprocess``.
    """
    print("logging file name:", audio_file_path)
    # Bound on every path, so the cleanup check below can never hit an
    # unbound local (it used to be assigned only in the MP3 branch).
    converted = audio_file_path.lower().endswith(".mp3")
    # NOTE(review): mp3_to_wav writes <same dir>/<same stem>.wav, so a
    # pre-existing WAV with that name is replaced and then deleted below.
    wav_path = mp3_to_wav(audio_file_path) if converted else audio_file_path
    try:
        res = model.generate(
            input=wav_path,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            merge_length_s=15,
        )
        return rich_transcription_postprocess(res[0]["text"])
    finally:
        # Remove the temporary WAV even when generate() raises.
        if converted:
            os.remove(wav_path)
def process_audio_file(audio_file_path):
    """Transcribe an audio file and write the text to a sibling Markdown file.

    The Markdown file shares the audio file's directory and base name
    (``talk.wav`` -> ``talk.md``) and is overwritten if it already exists.
    The transcribed text is also printed to stdout.

    Args:
        audio_file_path (str): Path to the audio file; passed directly to the model.
    """
    file_name = os.path.basename(audio_file_path)
    logging.info("Starting processing %s", file_name)

    # Output path: same directory and stem as the input, with a .md extension.
    stem = os.path.splitext(file_name)[0]
    md_file = os.path.join(os.path.dirname(audio_file_path), f"{stem}.md")

    res = model.generate(
        input=audio_file_path,
        cache={},
        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
        use_itn=True,
        batch_size_s=60,
        merge_vad=True,
        merge_length_s=15,
    )
    text = rich_transcription_postprocess(res[0]["text"])
    print(text)

    with open(md_file, "w", encoding="utf-8") as f:
        f.write(text)
def main():
# all_files = os.listdir('media')
# audio_files = [file for file in all_files if file.endswith('.wav')]
audio_files = []
for root, dirs, files in os.walk('media'):
for file in files:
if file.endswith('.wav'):
audio_files.append(os.path.join(root, file))
print(audio_files)
with ThreadPoolExecutor(max_workers=process_workers) as executor:
for audio_file in audio_files:
audio_file_path = os.path.join(audio_file)
executor.submit(process_audio_file, audio_file_path)
if __name__ == "__main__":
    # Script entry point; nothing runs when the module is merely imported.
    main()