everyone-can-use-english/enjoy/src/renderer/hooks/use-transcribe.tsx
an-lee · d9523269a3 · Fix stt hang up (#791) · 2024-07-10 12:59:50 +08:00
* handle transcribe error
* refactor


import {
  AppSettingsProviderContext,
  AISettingsProviderContext,
} from "@renderer/context";
import OpenAI from "openai";
import { useContext, useState } from "react";
import { t } from "i18next";
import { AI_WORKER_ENDPOINT } from "@/constants";
import * as sdk from "microsoft-cognitiveservices-speech-sdk";
import axios from "axios";
import { AlignmentResult } from "echogarden/dist/api/API.d.js";
import { useAiCommand } from "./use-ai-command";
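
// React hook wrapping the app's speech-to-text pipeline: transcode the audio,
// transcribe it with the configured service, repair punctuation if needed, and
// align the transcript against the audio for word-level timestamps.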
export const useTranscribe = () => {
  const { EnjoyApp, user, webApi } = useContext(AppSettingsProviderContext);
  const { openai } = useContext(AISettingsProviderContext);
  const { punctuateText } = useAiCommand();
  const [output, setOutput] = useState<string>("");
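
  // Convert the audio source to a normalized WAV via echogarden. A Blob is
  // first written to the cache directory so it can be referenced by path.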
  const transcode = async (src: string | Blob): Promise<string> => {
    if (src instanceof Blob) {
      src = await EnjoyApp.cacheObjects.writeFile(
        `${Date.now()}.${src.type.split("/")[1].split(";")[0]}`,
        await src.arrayBuffer()
      );
    }

    const output = await EnjoyApp.echogarden.transcode(src);
    return output;
  };
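
  // Transcribe `mediaSrc` with the selected service and return the engine,
  // model, aligned transcript, and (for Azure) the id of the token used.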
  const transcribe = async (
    mediaSrc: string,
    params?: {
      targetId?: string;
      targetType?: string;
      originalText?: string;
      language: string;
      service: WhisperConfigType["service"];
      isolate?: boolean;
    }
  ): Promise<{
    engine: string;
    model: string;
    alignmentResult: AlignmentResult;
    originalText?: string;
    tokenId?: number;
  }> => {
    const url = await transcode(mediaSrc);
    const {
      targetId,
      targetType,
      originalText,
      language,
      service,
      isolate = false,
    } = params || {};
    const blob = await (await fetch(url)).blob();

    let result;
    if (originalText) {
      result = {
        engine: "original",
        model: "original",
      };
    } else if (service === "local") {
      result = await transcribeByLocal(url, language);
    } else if (service === "cloudflare") {
      result = await transcribeByCloudflareAi(blob);
    } else if (service === "openai") {
      result = await transcribeByOpenAi(blob);
    } else if (service === "azure") {
      result = await transcribeByAzureAi(blob, language, {
        targetId,
        targetType,
      });
    } else {
      throw new Error(t("whisperServiceNotSupported"));
    }
    setOutput("");
    let transcript = originalText || result.text;
    // Remove all content inside `()`, `[]`, `{}` and trim the text
    transcript = transcript
      .replace(/\(.*?\)/g, "")
      .replace(/\[.*?\]/g, "")
      .replace(/\{.*?\}/g, "")
      .trim();

    // if the transcript does not contain any punctuation, use AI command to add punctuation
    if (!transcript.match(/\w[.,!?](\s|$)/)) {
      try {
        transcript = await punctuateText(transcript);
      } catch (err) {
        console.warn(err.message);
      }
    }
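
    // Align the cleaned transcript against the audio so downstream consumers
    // get timestamps for each word and sentence.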
    const alignmentResult = await EnjoyApp.echogarden.align(
      new Uint8Array(await blob.arrayBuffer()),
      transcript,
      {
        language,
        isolate,
      }
    );

    return {
      ...result,
      originalText,
      alignmentResult,
    };
  };
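
  // Transcribe with the locally installed whisper model; a punctuated initial
  // prompt is passed as an extra argument to encourage punctuated output.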
  const transcribeByLocal = async (url: string, language?: string) => {
    const res = await EnjoyApp.whisper.transcribe(
      {
        file: url,
      },
      {
        language,
        force: true,
        extra: ["--prompt", `"Hello! Welcome to listen to this audio."`],
      }
    );

    return {
      engine: "whisper",
      model: res.model.type,
      text: res.transcription.map((segment) => segment.text).join(" "),
    };
  };
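
  // Transcribe with the OpenAI Whisper API using the user's own API key.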
  const transcribeByOpenAi = async (blob: Blob) => {
    if (!openai?.key) {
      throw new Error(t("openaiKeyRequired"));
    }

    const client = new OpenAI({
      apiKey: openai.key,
      baseURL: openai.baseUrl,
      dangerouslyAllowBrowser: true,
      maxRetries: 0,
    });

    const res: { text: string } = (await client.audio.transcriptions.create({
      file: new File([blob], "audio.wav"),
      model: "whisper-1",
      response_format: "json",
    })) as any;

    return {
      engine: "openai",
      model: "whisper-1",
      text: res.text,
    };
  };
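
  // Transcribe through the app's Cloudflare AI worker, authenticated with the
  // user's access token.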
  const transcribeByCloudflareAi = async (blob: Blob) => {
    const res: CfWhipserOutputType = (
      await axios.postForm(`${AI_WORKER_ENDPOINT}/audio/transcriptions`, blob, {
        headers: {
          Authorization: `Bearer ${user.accessToken}`,
        },
        timeout: 1000 * 60 * 5,
      })
    ).data;

    return {
      engine: "cloudflare",
      model: "@cf/openai/whisper",
      text: res.text,
    };
  };
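
  // Transcribe with Azure Speech using a short-lived token issued by the web
  // API; interim recognition results are streamed into `output`.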
  const transcribeByAzureAi = async (
    blob: Blob,
    language: string,
    params?: {
      targetId?: string;
      targetType?: string;
    }
  ): Promise<{
    engine: string;
    model: string;
    text: string;
    tokenId: number;
  }> => {
    const { id, token, region } = await webApi.generateSpeechToken(params);
    const config = sdk.SpeechConfig.fromAuthorizationToken(token, region);
    const audioConfig = sdk.AudioConfig.fromWavFileInput(
      new File([blob], "audio.wav")
    );
    // setting the recognition language to learning language, such as 'en-US'.
    config.speechRecognitionLanguage = language;
    config.requestWordLevelTimestamps();
    config.outputFormat = sdk.OutputFormat.Detailed;

    // create the speech recognizer.
    const reco = new sdk.SpeechRecognizer(config, audioConfig);
    let results: SpeechRecognitionResultType[] = [];
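
    // Wire up the recognizer events: stream interim text into `output`,
    // accumulate detailed results, and resolve/reject when the session ends.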
    return new Promise((resolve, reject) => {
      reco.recognizing = (_s, e) => {
        console.log(e.result);
        setOutput(e.result.text);
      };

      reco.recognized = (_s, e) => {
        const json = e.result.properties.getProperty(
          sdk.PropertyId.SpeechServiceResponse_JsonResult
        );
        const result = JSON.parse(json);
        results = results.concat(result);
      };

      reco.canceled = (_s, e) => {
        if (e.reason === sdk.CancellationReason.Error) {
          return reject(new Error(e.errorDetails));
        }

        reco.stopContinuousRecognitionAsync();
      };

      reco.sessionStopped = (_s, _e) => {
        reco.stopContinuousRecognitionAsync();
        resolve({
          engine: "azure",
          model: "whisper",
          text: results.map((result) => result.DisplayText).join(" "),
          tokenId: id,
        });
      };

      reco.startContinuousRecognitionAsync();
    });
  };

  return {
    transcode,
    transcribe,
    output,
  };
};
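
// A minimal usage sketch (hypothetical caller, not part of this file), assuming
// a component that already has an audio `src` and a configured STT service:
//
//   const { transcribe, output } = useTranscribe();
//   const { engine, model, alignmentResult } = await transcribe(audio.src, {
//     language: "en-US",
//     service: "openai",
//   });
//   // `output` holds the interim text while Azure recognition is in progress.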