Feat: refactor STT service (#294)

* add STT hook interface

* fix crypto being exported to the browser

* refactor use-transcribe

* allow using OpenAI STT

* refactor: remove deprecated code

* fix undefined method
an-lee
2024-02-10 19:55:07 +08:00
committed by GitHub
parent a71671907e
commit bc22a5e2b4
21 changed files with 488 additions and 633 deletions
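
The refactor changes the hook's contract: instead of passing mediaId/mediaType and letting the main process persist the result, components now call transcribe(mediaUrl) and save the returned engine, model, and grouped result themselves. A minimal caller-side sketch of the new flow, following the component diff below (error handling and state setters are illustrative):

const { transcribe } = useTranscribe();

const generate = async () => {
  setTranscribing(true);
  try {
    // transcribe() transcodes the media to 16 kHz mono WAV, then dispatches to
    // the configured service: local whisper, Azure, Cloudflare, or OpenAI.
    const { engine, model, result } = await transcribe(mediaUrl);
    // Persisting the finished transcription is now the caller's responsibility.
    await EnjoyApp.transcriptions.update(transcription.id, {
      state: "finished",
      result,
      engine,
      model,
    });
  } catch (err) {
    toast.error(err.message);
  } finally {
    setTranscribing(false);
  }
};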

View File

@@ -12,6 +12,7 @@ import {
ScrollArea,
Button,
PingPoint,
toast,
} from "@renderer/components/ui";
import React, { useEffect, useContext, useState } from "react";
import { t } from "i18next";
@@ -19,6 +20,7 @@ import { LoaderIcon, CheckCircleIcon, MicIcon } from "lucide-react";
import {
DbProviderContext,
AppSettingsProviderContext,
AISettingsProviderContext,
} from "@renderer/context";
import { useTranscribe } from "@renderer/hooks";
@@ -32,6 +34,7 @@ export const MediaTranscription = (props: {
onSelectSegment?: (index: number) => void;
}) => {
const { addDblistener, removeDbListener } = useContext(DbProviderContext);
const { whisperConfig } = useContext(AISettingsProviderContext);
const { EnjoyApp } = useContext(AppSettingsProviderContext);
const {
transcription,
@@ -55,13 +58,19 @@ export const MediaTranscription = (props: {
setTranscribing(true);
setProgress(0);
transcribe({
mediaId,
mediaType,
mediaSrc: mediaUrl,
}).finally(() => {
setTranscribing(false);
});
try {
const { engine, model, result } = await transcribe(mediaUrl);
await EnjoyApp.transcriptions.update(transcription.id, {
state: "finished",
result,
engine,
model,
});
} catch (err) {
toast.error(err.message);
}
setTranscribing(false);
};
const fetchSegmentStats = async () => {
@@ -80,14 +89,16 @@ export const MediaTranscription = (props: {
generate();
}
EnjoyApp.transcriptions.onProgress((_, p: number) => {
if (p > 100) p = 100;
setProgress(p);
});
if (whisperConfig.service === "local") {
EnjoyApp.whisper.onProgress((_, p: number) => {
if (p > 100) p = 100;
setProgress(p);
});
}
return () => {
removeDbListener(fetchSegmentStats);
EnjoyApp.transcriptions.removeProgressListeners();
EnjoyApp.whisper.removeProgressListeners();
};
}, [mediaId, mediaType]);
@@ -114,7 +125,9 @@ export const MediaTranscription = (props: {
{transcribing || transcription.state === "processing" ? (
<>
<PingPoint colorClassName="bg-yellow-500" />
<div className="text-sm">{progress}%</div>
<div className="text-sm">
{whisperConfig.service === "local" && `${progress}%`}
</div>
</>
) : transcription.state === "finished" ? (
<CheckCircleIcon className="text-green-500 w-4 h-4" />
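
The onProgress/removeProgressListeners pair used above implies a preload bridge between the main process and the renderer. A minimal sketch of what that wiring might look like, assuming a standard Electron contextBridge setup (the channel name "whisper-on-progress" and the exposed object shape are assumptions, not part of this diff):

// Preload sketch (assumed), exposing whisper progress events to the renderer.
import { contextBridge, ipcRenderer } from "electron";

contextBridge.exposeInMainWorld("EnjoyApp", {
  whisper: {
    // Forward progress events pushed from the main process.
    onProgress: (callback: (event: unknown, progress: number) => void) =>
      ipcRenderer.on("whisper-on-progress", callback),
    // Remove every progress listener, e.g. on component unmount.
    removeProgressListeners: () =>
      ipcRenderer.removeAllListeners("whisper-on-progress"),
  },
});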

View File

@@ -77,6 +77,8 @@ export const WhisperSettings = () => {
t("azureSpeechToTextDescription")}
{whisperConfig?.service === "cloudflare" &&
t("cloudflareSpeechToTextDescription")}
{whisperConfig?.service === "openai" &&
t("openaiSpeechToTextDescription")}
</div>
</div>
@@ -94,6 +96,7 @@ export const WhisperSettings = () => {
<SelectItem value="local">{t("local")}</SelectItem>
<SelectItem value="azure">{t("azureAi")}</SelectItem>
<SelectItem value="cloudflare">{t("cloudflareAi")}</SelectItem>
<SelectItem value="openai">OpenAI</SelectItem>
</SelectContent>
</Select>
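
The select above adds "openai" as a fourth service value. A sketch of the whisperConfig shape these values feed, as assumed from this settings panel and useTranscribe (the actual type declaration is not part of this diff):

// Assumed shape; only the service field is visible in this commit, other
// settings (model selection, etc.) are omitted here.
type WhisperConfigType = {
  service: "local" | "azure" | "cloudflare" | "openai";
};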

View File

@@ -1,3 +1,3 @@
export * from './use-transcode';
export * from './use-transcribe';
export * from './use-ai-command';
export * from './use-conversation';

View File

@@ -1,58 +0,0 @@
import { AppSettingsProviderContext } from "@renderer/context";
import { useContext } from "react";
import { toast } from "@renderer/components/ui";
import { t } from "i18next";
import { fetchFile } from "@ffmpeg/util";
export const useTranscribe = () => {
const { EnjoyApp, ffmpeg } = useContext(AppSettingsProviderContext);
const transcode = async (src: string, options?: string[]) => {
if (!ffmpeg?.loaded) return;
options = options || ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le"];
try {
const uri = new URL(src);
const input = uri.pathname.split("/").pop();
const output = input.replace(/\.[^/.]+$/, ".wav");
await ffmpeg.writeFile(input, await fetchFile(src));
await ffmpeg.exec(["-i", input, ...options, output]);
const data = await ffmpeg.readFile(output);
return new Blob([data], { type: "audio/wav" });
} catch (e) {
toast.error(t("transcodeError"));
}
};
const transcribe = async (params: {
mediaSrc: string;
mediaId: string;
mediaType: "Audio" | "Video";
}) => {
const { mediaSrc, mediaId, mediaType } = params;
const data = await transcode(mediaSrc);
let blob;
if (data) {
blob = {
type: data.type.split(";")[0],
arrayBuffer: await data.arrayBuffer(),
};
}
return EnjoyApp.transcriptions.process(
{
targetId: mediaId,
targetType: mediaType,
},
{
blob,
}
);
};
return {
transcode,
transcribe,
};
};

View File

@@ -0,0 +1,263 @@
import {
AppSettingsProviderContext,
AISettingsProviderContext,
} from "@renderer/context";
import OpenAI from "openai";
import { useContext } from "react";
import { toast } from "@renderer/components/ui";
import { t } from "i18next";
import { fetchFile } from "@ffmpeg/util";
import { AI_WORKER_ENDPOINT } from "@/constants";
import * as sdk from "microsoft-cognitiveservices-speech-sdk";
import axios from "axios";
import take from "lodash/take";
import sortedUniqBy from "lodash/sortedUniqBy";
import { groupTranscription, END_OF_WORD_REGEX, milisecondsToTimestamp } from "@/utils";
export const useTranscribe = () => {
const { EnjoyApp, ffmpeg, user, webApi } = useContext(
AppSettingsProviderContext
);
const { whisperConfig, openai } = useContext(AISettingsProviderContext);
const transcode = async (src: string, options?: string[]) => {
if (!ffmpeg?.loaded) return;
options = options || ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le"];
try {
const uri = new URL(src);
const input = uri.pathname.split("/").pop();
const output = input.replace(/\.[^/.]+$/, ".wav");
await ffmpeg.writeFile(input, await fetchFile(src));
await ffmpeg.exec(["-i", input, ...options, output]);
const data = await ffmpeg.readFile(output);
return new Blob([data], { type: "audio/wav" });
} catch (e) {
toast.error(t("transcodeError"));
}
};
const transcribe = async (
mediaSrc: string
): Promise<{
engine: string;
model: string;
result: TranscriptionResultSegmentGroupType[];
}> => {
const blob = await transcode(mediaSrc);
if (whisperConfig.service === "local") {
return transcribeByLocal(blob);
} else if (whisperConfig.service === "cloudflare") {
return transcribeByCloudflareAi(blob);
} else if (whisperConfig.service === "openai") {
return transcribeByOpenAi(blob);
} else if (whisperConfig.service === "azure") {
return transcribeByAzureAi(blob);
} else {
throw new Error(t("whisperServiceNotSupported"));
}
};
const transcribeByLocal = async (blob: Blob) => {
const res = await EnjoyApp.whisper.transcribe(
{
blob: {
type: blob.type.split(";")[0],
arrayBuffer: await blob.arrayBuffer(),
},
},
{
force: true,
extra: ["--prompt", `"Hello! Welcome to listen to this audio."`],
}
);
const result = groupTranscription(res.transcription);
return {
engine: "whisper",
model: res.model.type,
result,
};
};
const transcribeByOpenAi = async (blob: Blob) => {
if (!openai?.key) {
throw new Error(t("openaiKeyRequired"));
}
const client = new OpenAI({
apiKey: openai.key,
baseURL: openai.baseUrl,
dangerouslyAllowBrowser: true,
});
const res: {
words: {
word: string;
start: number;
end: number;
}[];
} = (await client.audio.transcriptions.create({
file: new File([blob], "audio.wav"),
model: "whisper-1",
response_format: "verbose_json",
timestamp_granularities: ["word"],
})) as any;
const transcription: TranscriptionResultSegmentType[] = res.words.map(
(word) => {
return {
offsets: {
from: word.start * 1000,
to: word.end * 1000,
},
timestamps: {
from: milisecondsToTimestamp(word.start * 1000),
to: milisecondsToTimestamp(word.end * 1000),
},
text: word.word,
};
}
);
const result = groupTranscription(transcription);
return {
engine: "openai",
model: "whisper-1",
result,
};
};
const transcribeByCloudflareAi = async (blob: Blob) => {
const res: CfWhipserOutputType = (
await axios.postForm(`${AI_WORKER_ENDPOINT}/audio/transcriptions`, blob, {
headers: {
Authorization: `Bearer ${user.accessToken}`,
},
timeout: 1000 * 60 * 5,
})
).data;
const transcription: TranscriptionResultSegmentType[] = res.words.map(
(word) => {
return {
offsets: {
from: word.start * 1000,
to: word.end * 1000,
},
timestamps: {
from: milisecondsToTimestamp(word.start * 1000),
to: milisecondsToTimestamp(word.end * 1000),
},
text: word.word,
};
}
);
const result = groupTranscription(transcription);
return {
engine: "cloudflare",
model: "@cf/openai/whisper",
result,
};
};
const transcribeByAzureAi = async (
blob: Blob
): Promise<{
engine: string;
model: string;
result: TranscriptionResultSegmentGroupType[];
}> => {
const { token, region } = await webApi.generateSpeechToken();
const config = sdk.SpeechConfig.fromAuthorizationToken(token, region);
const audioConfig = sdk.AudioConfig.fromWavFileInput(
new File([blob], "audio.wav")
);
// setting the recognition language to English.
config.speechRecognitionLanguage = "en-US";
config.requestWordLevelTimestamps();
config.outputFormat = sdk.OutputFormat.Detailed;
// create the speech recognizer.
const reco = new sdk.SpeechRecognizer(config, audioConfig);
let results: SpeechRecognitionResultType[] = [];
return new Promise((resolve, reject) => {
reco.recognizing = (_s, e) => {
console.log(e.result.text);
};
reco.recognized = (_s, e) => {
const json = e.result.properties.getProperty(
sdk.PropertyId.SpeechServiceResponse_JsonResult
);
const result = JSON.parse(json);
results = results.concat(result);
};
reco.canceled = (_s, e) => {
if (e.reason === sdk.CancellationReason.Error) {
return reject(new Error(e.errorDetails));
}
reco.stopContinuousRecognitionAsync();
};
reco.sessionStopped = (_s, _e) => {
reco.stopContinuousRecognitionAsync();
const transcription: TranscriptionResultSegmentType[] = [];
results.forEach((result) => {
const best = take(sortedUniqBy(result.NBest, "Confidence"), 1)[0];
const words = best.Display.trim().split(" ");
best.Words.map((word, index) => {
let text = word.Word;
if (words.length === best.Words.length) {
text = words[index];
}
if (
index === best.Words.length - 1 &&
!text.trim().match(END_OF_WORD_REGEX)
) {
text = text + ".";
}
transcription.push({
offsets: {
from: word.Offset / 1e4,
to: (word.Offset + word.Duration) / 1e4,
},
timestamps: {
from: milisecondsToTimestamp(word.Offset / 1e4),
to: milisecondsToTimestamp((word.Offset + word.Duration) / 1e4),
},
text,
});
});
});
resolve({
engine: "azure",
model: "whisper",
result: groupTranscription(transcription),
});
};
reco.startContinuousRecognitionAsync();
});
};
return {
transcode,
transcribe,
};
};
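
All four service branches normalize their output into the same word-level segments before groupTranscription merges them into sentence-like groups. A sketch of the segment types assumed by the mapping code above (field names follow the mappers; the real declarations and the exact timestamp string format live elsewhere in the repo):

// Assumed word-level segment, as built by the OpenAI/Cloudflare/Azure mappers.
type TranscriptionResultSegmentType = {
  offsets: { from: number; to: number }; // milliseconds
  timestamps: { from: string; to: string }; // formatted by milisecondsToTimestamp
  text: string;
};

// Assumed grouped result returned by transcribe(); one entry per group of
// segments produced by groupTranscription.
type TranscriptionResultSegmentGroupType = {
  offsets: { from: number; to: number };
  timestamps: { from: string; to: string };
  text: string;
  segments: TranscriptionResultSegmentType[];
};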