Feat: refactor STT service (#294)

* add stt hook interface

* fix Node crypto being exported to the browser

* refactor use-transcribe

* allow using OpenAI STT

* refactor: remove deprecated code

* fix undefined method
an-lee
2024-02-10 19:55:07 +08:00
committed by GitHub
parent a71671907e
commit bc22a5e2b4
21 changed files with 488 additions and 633 deletions
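
For orientation, a minimal sketch of the renderer-side flow after this refactor, pieced together from the MediaTranscription and use-transcribe changes below; the hook and IPC names come from this diff, while the component wiring around the two calls is illustrative:

// Sketch only: the hook picks the configured STT service (local whisper,
// Azure, Cloudflare, or OpenAI) and returns grouped segments; the renderer
// then persists the result itself instead of delegating the whole pipeline
// to the main process via "transcriptions-process".
const { transcribe } = useTranscribe();

const generate = async () => {
  const { engine, model, result } = await transcribe(mediaUrl);
  await EnjoyApp.transcriptions.update(transcription.id, {
    state: "finished",
    result,
    engine,
    model,
  });
};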

View File

@@ -317,6 +317,7 @@
"azureSpeechToTextDescription": "Use Azure AI Speech to transcribe. It is a paid service.",
"cloudflareAi": "Cloudflare AI",
"cloudflareSpeechToTextDescription": "Use Cloudflare AI Worker to transcribe. It is in beta and free for now.",
"openaiSpeechToTextDescription": "Use openAI to transcribe using your own key.",
"checkingWhisper": "Checking whisper status",
"pleaseDownloadWhisperModelFirst": "Please download whisper model first",
"whisperIsWorkingGood": "Whisper is working good",

View File

@@ -316,6 +316,7 @@
"azureSpeechToTextDescription": "使用 Azure AI Speech 进行语音转文本,收费服务",
"cloudflareAi": "Cloudflare AI",
"cloudflareSpeechToTextDescription": "使用 Cloudflare AI 进行语音转文本,目前免费",
"openaiSpeechToTextDescription": "使用 OpenAI 进行语音转文本(需要 API 密钥)",
"checkingWhisper": "正在检查 Whisper",
"pleaseDownloadWhisperModelFirst": "请先下载 Whisper 模型",
"whisperIsWorkingGood": "Whisper 正常工作",

View File

@@ -3,7 +3,7 @@ import { Speech } from "@main/db/models";
import fs from "fs-extra";
import path from "path";
import settings from "@main/settings";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
class SpeechesHandler {
private async create(

View File

@@ -1,7 +1,6 @@
import { ipcMain, IpcMainEvent } from "electron";
import { Transcription, Audio, Video } from "@main/db/models";
import { WhereOptions, Attributes } from "sequelize";
import { t } from "i18next";
import { Attributes } from "sequelize";
import log from "electron-log/main";
const logger = log.scope("db/handlers/transcriptions-handler");
@@ -44,7 +43,7 @@ class TranscriptionsHandler {
id: string,
params: Attributes<Transcription>
) {
const { result } = params;
const { result, engine, model, state } = params;
return Transcription.findOne({
where: { id },
@@ -53,63 +52,7 @@ class TranscriptionsHandler {
if (!transcription) {
throw new Error("models.transcription.notFound");
}
transcription.update({ result });
})
.catch((err) => {
logger.error(err);
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
});
}
private async process(
event: IpcMainEvent,
where: WhereOptions<Attributes<Transcription>>,
options?: {
force?: boolean;
blob: {
type: string;
arrayBuffer: ArrayBuffer;
};
}
) {
const { force = true, blob } = options || {};
return Transcription.findOne({
where: {
...where,
},
})
.then((transcription) => {
if (!transcription) {
throw new Error("models.transcription.notFound");
}
const interval = setInterval(() => {
event.sender.send("on-notification", {
type: "warning",
message: t("stillTranscribing"),
});
}, 1000 * 10);
transcription
.process({
force,
wavFileBlob: blob,
onProgress: (progress: number) => {
event.sender.send("transcription-on-progress", progress);
},
})
.catch((err) => {
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
})
.finally(() => {
clearInterval(interval);
});
transcription.update({ result, engine, model, state });
})
.catch((err) => {
logger.error(err);
@@ -122,7 +65,6 @@ class TranscriptionsHandler {
register() {
ipcMain.handle("transcriptions-find-or-create", this.findOrCreate);
ipcMain.handle("transcriptions-process", this.process);
ipcMain.handle("transcriptions-update", this.update);
}
}

View File

@@ -17,7 +17,7 @@ import {
import { Recording, Speech, Transcription, Video } from "@main/db/models";
import settings from "@main/settings";
import { AudioFormats, VideoFormats, WEB_API_URL } from "@/constants";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import path from "path";
import fs from "fs-extra";
import { t } from "i18next";
@@ -191,15 +191,6 @@ export class Audio extends Model<Audio> {
}
}
@AfterCreate
static transcribeAsync(audio: Audio) {
if (settings.ffmpegConfig().ready) {
setTimeout(() => {
audio.transcribe();
}, 500);
}
}
@AfterCreate
static autoSync(audio: Audio) {
// auto sync should not block the main thread
@@ -332,38 +323,6 @@ export class Audio extends Model<Audio> {
});
}
// STT using whisper
async transcribe() {
Transcription.findOrCreate({
where: {
targetId: this.id,
targetType: "Audio",
},
defaults: {
targetId: this.id,
targetType: "Audio",
targetMd5: this.md5,
},
})
.then(([transcription, _created]) => {
if (transcription.state === "pending") {
transcription.process();
} else if (transcription.state === "finished") {
transcription.process({ force: true });
} else if (transcription.state === "processing") {
logger.warn(
`[${transcription.getDataValue("id")}]`,
"Transcription is processing."
);
}
})
.catch((err) => {
logger.error(err);
throw err;
});
}
static notify(audio: Audio, action: "create" | "update" | "destroy") {
if (!mainWindow.win) return;

View File

@@ -29,7 +29,7 @@ import fs from "fs-extra";
import path from "path";
import Ffmpeg from "@main/ffmpeg";
import whisper from "@main/whisper";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import { WEB_API_URL } from "@/constants";
import proxyAgent from "@main/proxy-agent";

View File

@@ -19,7 +19,7 @@ import { Audio, PronunciationAssessment, Video } from "@main/db/models";
import fs from "fs-extra";
import path from "path";
import settings from "@main/settings";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import log from "electron-log/main";
import storage from "@main/storage";
import { Client } from "@/api";

View File

@@ -20,7 +20,7 @@ import path from "path";
import settings from "@main/settings";
import OpenAI, { type ClientOptions } from "openai";
import { t } from "i18next";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import { Audio, Message } from "@main/db/models";
import log from "electron-log/main";
import { WEB_API_URL } from "@/constants";

View File

@@ -1,5 +1,4 @@
import {
AfterCreate,
AfterUpdate,
AfterDestroy,
AfterFind,
@@ -13,18 +12,13 @@ import {
Unique,
} from "sequelize-typescript";
import { Audio, Video } from "@main/db/models";
import whisper from "@main/whisper";
import mainWindow from "@main/window";
import log from "electron-log/main";
import { Client } from "@/api";
import { WEB_API_URL, PROCESS_TIMEOUT } from "@/constants";
import settings from "@main/settings";
import Ffmpeg from "@main/ffmpeg";
import path from "path";
import fs from "fs-extra";
const logger = log.scope("db/models/transcription");
@Table({
modelName: "Transcription",
tableName: "transcriptions",
@@ -80,120 +74,13 @@ export class Transcription extends Model<Transcription> {
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
logger,
});
return webApi.syncTranscription(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
});
}
// STT using whisper
async process(
options: {
force?: boolean;
wavFileBlob?: { type: string; arrayBuffer: ArrayBuffer };
onProgress?: (progress: number) => void;
} = {}
) {
if (this.getDataValue("state") === "processing") return;
const { force = false, wavFileBlob, onProgress } = options;
logger.info(`[${this.getDataValue("id")}]`, "Start to transcribe.");
let filePath = "";
if (this.targetType === "Audio") {
filePath = (await Audio.findByPk(this.targetId)).filePath;
} else if (this.targetType === "Video") {
filePath = (await Video.findByPk(this.targetId)).filePath;
}
if (!filePath) {
logger.error(`[${this.getDataValue("id")}]`, "No file path.");
throw new Error("No file path.");
}
let wavFile: string = filePath;
const tmpDir = settings.cachePath();
const outputFile = path.join(
tmpDir,
path.basename(filePath, path.extname(filePath)) + ".wav"
);
if (wavFileBlob) {
const format = wavFileBlob.type.split("/")[1];
if (format !== "wav") {
throw new Error("Only wav format is supported");
}
await fs.outputFile(outputFile, Buffer.from(wavFileBlob.arrayBuffer));
wavFile = outputFile;
} else if (settings.ffmpegConfig().ready) {
const ffmpeg = new Ffmpeg();
try {
wavFile = await ffmpeg.prepareForWhisper(
filePath,
path.join(
tmpDir,
path.basename(filePath, path.extname(filePath)) + ".wav"
)
);
} catch (err) {
logger.error("ffmpeg error", err);
}
}
try {
await this.update({
state: "processing",
});
const {
engine = "whisper",
model,
transcription,
} = await whisper.transcribe(wavFile, {
force,
extra: [
"--split-on-word",
"--max-len",
"1",
"--prompt",
`"Hello! Welcome to listen to this audio."`,
],
onProgress,
});
const result = whisper.groupTranscription(transcription);
this.update({
engine,
model: model?.type,
result,
state: "finished",
}).then(() => this.sync());
logger.info(`[${this.getDataValue("id")}]`, "Transcription finished.");
} catch (err) {
logger.error(
`[${this.getDataValue("id")}]`,
"Transcription not finished.",
err
);
this.update({
state: "pending",
});
throw err;
}
}
@AfterCreate
static startTranscribeAsync(transcription: Transcription) {
setTimeout(() => {
transcription.process();
}, 0);
}
@AfterUpdate
static notifyForUpdate(transcription: Transcription) {
this.notify(transcription, "update");

View File

@@ -17,7 +17,7 @@ import {
import { Audio, Recording, Speech, Transcription } from "@main/db/models";
import settings from "@main/settings";
import { AudioFormats, VideoFormats, WEB_API_URL } from "@/constants";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import path from "path";
import fs from "fs-extra";
import { t } from "i18next";
@@ -213,15 +213,6 @@ export class Video extends Model<Video> {
}
}
@AfterCreate
static transcribeAsync(video: Video) {
if (settings.ffmpegConfig().ready) {
setTimeout(() => {
video.transcribe();
}, 500);
}
}
@AfterCreate
static autoSync(video: Video) {
// auto sync should not block the main thread
@@ -355,37 +346,6 @@ export class Video extends Model<Video> {
});
}
async transcribe() {
Transcription.findOrCreate({
where: {
targetId: this.id,
targetType: "Video",
},
defaults: {
targetId: this.id,
targetType: "Video",
targetMd5: this.md5,
},
})
.then(([transcription, _created]) => {
if (transcription.state === "pending") {
transcription.process();
} else if (transcription.state === "finished") {
transcription.process({ force: true });
} else if (transcription.state === "processing") {
logger.warn(
`[${transcription.getDataValue("id")}]`,
"Transcription is processing."
);
}
})
.catch((err) => {
logger.error(err);
throw err;
});
}
static notify(video: Video, action: "create" | "update" | "destroy") {
if (!mainWindow.win) return;

enjoy/src/main/utils.ts (new file)
View File

@@ -0,0 +1,38 @@
import { createHash } from "crypto";
import { createReadStream } from "fs";
export function hashFile(
path: string,
options: { algo: string }
): Promise<string> {
const algo = options.algo || "md5";
return new Promise((resolve, reject) => {
const hash = createHash(algo);
const stream = createReadStream(path);
stream.on("error", reject);
stream.on("data", (chunk) => hash.update(chunk));
stream.on("end", () => resolve(hash.digest("hex")));
});
}
export function hashBlob(
blob: Blob,
options: { algo: string }
): Promise<string> {
const algo = options.algo || "md5";
return new Promise((resolve, reject) => {
const hash = createHash(algo);
const reader = new FileReader();
reader.onload = () => {
if (reader.result instanceof ArrayBuffer) {
const buffer = Buffer.from(reader.result);
hash.update(buffer);
resolve(hash.digest("hex"));
} else {
reject(new Error("Unexpected result from FileReader"));
}
};
reader.onerror = reject;
reader.readAsArrayBuffer(blob);
});
}
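
A hedged usage sketch for the new main-process helper; the signature above requires the options object, and the file path here is purely illustrative:

import { hashFile } from "@main/utils";

// Illustrative only: fingerprint a local media file with the md5 algorithm.
async function fingerprint(filePath: string): Promise<string> {
  return hashFile(filePath, { algo: "md5" });
}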

View File

@@ -152,162 +152,17 @@ class Whipser {
});
}
async transcribeBlob(
blob: { type: string; arrayBuffer: ArrayBuffer },
options?: {
prompt?: string;
group?: boolean;
}
): Promise<
TranscriptionResultSegmentType[] | TranscriptionResultSegmentGroupType[]
> {
const { prompt, group = false } = options || {};
const format = blob.type.split("/")[1];
if (format !== "wav") {
throw new Error("Only wav format is supported");
}
const tempfile = path.join(settings.cachePath(), `${Date.now()}.${format}`);
await fs.outputFile(tempfile, Buffer.from(blob.arrayBuffer));
const extra = [];
if (prompt) {
extra.push(`--prompt "${prompt.replace(/"/g, '\\"')}"`);
}
const { transcription } = await this.transcribe(tempfile, {
force: true,
extra,
});
if (group) {
return this.groupTranscription(transcription);
} else {
return transcription;
}
}
async transcribe(
file: string,
options?: {
force?: boolean;
extra?: string[];
onProgress?: (progress: number) => void;
}
): Promise<Partial<WhisperOutputType>> {
if (this.config.service === "local") {
return this.transcribeFromLocal(file, options);
} else if (this.config.service === "azure") {
return this.transcribeFromAzure(file);
} else if (this.config.service === "cloudflare") {
return this.transcribeFromCloudflare(file);
} else {
throw new Error("Unknown service");
}
}
async transcribeFromAzure(file: string): Promise<Partial<WhisperOutputType>> {
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
const { token, region } = await webApi.generateSpeechToken();
const sdk = new AzureSpeechSdk(token, region);
const results = await sdk.transcribe({
filePath: file,
});
const transcription: TranscriptionResultSegmentType[] = [];
results.forEach((result) => {
logger.debug(result);
const best = take(sortedUniqBy(result.NBest, "Confidence"), 1)[0];
const words = best.Display.trim().split(" ");
best.Words.map((word, index) => {
let text = word.Word;
if (words.length === best.Words.length) {
text = words[index];
}
if (
index === best.Words.length - 1 &&
!text.trim().match(END_OF_WORD_REGEX)
) {
text = text + ".";
}
transcription.push({
offsets: {
from: word.Offset / 1e4,
to: (word.Offset + word.Duration) / 1e4,
},
timestamps: {
from: milisecondsToTimestamp(word.Offset / 1e4),
to: milisecondsToTimestamp((word.Offset + word.Duration) * 1e4),
},
text,
});
});
});
return {
engine: "azure",
model: {
type: "Azure AI Speech",
},
transcription,
};
}
async transcribeFromCloudflare(
file: string
): Promise<Partial<WhisperOutputType>> {
logger.debug("transcribing from CloudFlare");
const data = fs.readFileSync(file);
const res: CfWhipserOutputType = (
await axios.postForm(`${AI_WORKER_ENDPOINT}/audio/transcriptions`, data, {
headers: {
Authorization: `Bearer ${settings.getSync("user.accessToken")}`,
},
})
).data;
logger.debug("transcription from Web,", res);
const transcription: TranscriptionResultSegmentType[] = res.words.map(
(word) => {
return {
offsets: {
from: word.start * 1000,
to: word.end * 1000,
},
timestamps: {
from: milisecondsToTimestamp(word.start * 1000),
to: milisecondsToTimestamp(word.end * 1000),
},
text: word.word,
};
}
);
logger.debug("converted transcription,", transcription);
return {
engine: "cloudflare",
model: {
type: "@cf/openai/whisper",
},
transcription,
};
}
/* Ensure the file is in wav format
* and 16kHz sample rate
*/
async transcribeFromLocal(
file: string,
async transcribe(
params: {
file?: string;
blob?: {
type: string;
arrayBuffer: ArrayBuffer;
};
},
options?: {
force?: boolean;
extra?: string[];
@@ -315,6 +170,28 @@ class Whipser {
}
): Promise<Partial<WhisperOutputType>> {
logger.debug("transcribing from local");
const { blob } = params;
let { file } = params;
if (!file && !blob) {
throw new Error("No file or blob provided");
}
if (!this.currentModel()) {
throw new Error(t("pleaseDownloadWhisperModelFirst"));
}
if (blob) {
const format = blob.type.split("/")[1];
if (format !== "wav") {
throw new Error("Only wav format is supported");
}
file = path.join(settings.cachePath(), `${Date.now()}.${format}`);
await fs.outputFile(file, Buffer.from(blob.arrayBuffer));
}
const { force = false, extra = [], onProgress } = options || {};
const filename = path.basename(file, path.extname(file));
const tmpDir = settings.cachePath();
@@ -326,46 +203,35 @@ class Whipser {
return fs.readJson(outputFile);
}
if (!this.currentModel()) {
throw new Error(t("pleaseDownloadWhisperModelFirst"));
}
const command = [
`"${this.binMain}"`,
`--file "${file}"`,
`--model "${this.currentModel()}"`,
const commandArguments = [
"--file",
file,
"--model",
this.currentModel(),
"--output-json",
`--output-file "${path.join(tmpDir, filename)}"`,
"--output-file",
path.join(tmpDir, filename),
"-pp",
"--split-on-word",
"--max-len",
"1",
...extra,
].join(" ");
];
logger.info(`Running command: ${command}`);
const transcribe = spawn(
this.binMain,
[
"--file",
file,
"--model",
this.currentModel(),
"--output-json",
"--output-file",
path.join(tmpDir, filename),
"-pp",
...extra,
],
{
timeout: PROCESS_TIMEOUT,
}
logger.info(
`Running command: ${this.binMain} ${commandArguments.join(" ")}`
);
const command = spawn(this.binMain, commandArguments, {
timeout: PROCESS_TIMEOUT,
});
return new Promise((resolve, reject) => {
transcribe.stdout.on("data", (data) => {
command.stdout.on("data", (data) => {
logger.debug(`stdout: ${data}`);
});
transcribe.stderr.on("data", (data) => {
command.stderr.on("data", (data) => {
const output = data.toString();
logger.error(`stderr: ${output}`);
if (output.startsWith("whisper_print_progress_callback")) {
@@ -374,16 +240,16 @@ class Whipser {
}
});
transcribe.on("exit", (code) => {
command.on("exit", (code) => {
logger.info(`transcribe process exited with code ${code}`);
});
transcribe.on("error", (err) => {
command.on("error", (err) => {
logger.error("transcribe error", err.message);
reject(err);
});
transcribe.on("close", () => {
command.on("close", () => {
if (fs.pathExistsSync(outputFile)) {
resolve(fs.readJson(outputFile));
} else {
@@ -393,57 +259,6 @@ class Whipser {
});
}
groupTranscription(
transcription: TranscriptionResultSegmentType[]
): TranscriptionResultSegmentGroupType[] {
const generateGroup = (group?: TranscriptionResultSegmentType[]) => {
if (!group || group.length === 0) return;
const firstWord = group[0];
const lastWord = group[group.length - 1];
return {
offsets: {
from: firstWord.offsets.from,
to: lastWord.offsets.to,
},
text: group.map((w) => w.text.trim()).join(" "),
timestamps: {
from: firstWord.timestamps.from,
to: lastWord.timestamps.to,
},
segments: group,
};
};
const groups: TranscriptionResultSegmentGroupType[] = [];
let group: TranscriptionResultSegmentType[] = [];
transcription.forEach((segment) => {
const text = segment.text.trim();
if (!text) return;
group.push(segment);
if (
!MAGIC_TOKENS.includes(text) &&
segment.text.trim().match(END_OF_WORD_REGEX)
) {
// Group a complete sentence;
groups.push(generateGroup(group));
// init a new group
group = [];
}
});
// Group the last group
const lastSentence = generateGroup(group);
if (lastSentence) groups.push(lastSentence);
return groups;
}
registerIpcHandlers() {
ipcMain.handle("whisper-config", async () => {
try {
@@ -489,7 +304,7 @@ class Whipser {
message: err.message,
});
}
} else if (["cloudflare", "azure"].includes(service)) {
} else if (["cloudflare", "azure", "openai"].includes(service)) {
settings.setSync("whisper.service", service);
this.config.service = service;
return this.config;
@@ -505,9 +320,14 @@ class Whipser {
return await this.check();
});
ipcMain.handle("whisper-transcribe-blob", async (event, blob, prompt) => {
ipcMain.handle("whisper-transcribe", async (event, params, options) => {
try {
return await this.transcribeBlob(blob, prompt);
return await this.transcribe(params, {
...options,
onProgress: (progress) => {
event.sender.send("whisper-on-progress", progress);
},
});
} catch (err) {
event.sender.send("on-notification", {
type: "error",

View File

@@ -363,11 +363,26 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
check: () => {
return ipcRenderer.invoke("whisper-check");
},
transcribeBlob: (
blob: { type: string; arrayBuffer: ArrayBuffer },
prompt?: string
transcribe: (
params: {
file?: string;
blob?: {
type: string;
arrayBuffer: ArrayBuffer;
};
},
options?: {
force?: boolean;
extra?: string[];
}
) => {
return ipcRenderer.invoke("whisper-transcribe-blob", blob, prompt);
return ipcRenderer.invoke("whisper-transcribe", params, options);
},
onProgress: (
callback: (event: IpcRendererEvent, progress: number) => void
) => ipcRenderer.on("whisper-on-progress", callback),
removeProgressListeners: () => {
ipcRenderer.removeAllListeners("whisper-on-progress");
},
},
ffmpeg: {
@@ -425,18 +440,9 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
findOrCreate: (params: any) => {
return ipcRenderer.invoke("transcriptions-find-or-create", params);
},
process: (params: any, options: any) => {
return ipcRenderer.invoke("transcriptions-process", params, options);
},
update: (id: string, params: any) => {
return ipcRenderer.invoke("transcriptions-update", id, params);
},
onProgress: (
callback: (event: IpcRendererEvent, progress: number) => void
) => ipcRenderer.on("transcription-on-progress", callback),
removeProgressListeners: () => {
ipcRenderer.removeAllListeners("transcription-on-progress");
},
},
waveforms: {
find: (id: string) => {

View File

@@ -12,6 +12,7 @@ import {
ScrollArea,
Button,
PingPoint,
toast,
} from "@renderer/components/ui";
import React, { useEffect, useContext, useState } from "react";
import { t } from "i18next";
@@ -19,6 +20,7 @@ import { LoaderIcon, CheckCircleIcon, MicIcon } from "lucide-react";
import {
DbProviderContext,
AppSettingsProviderContext,
AISettingsProviderContext,
} from "@renderer/context";
import { useTranscribe } from "@renderer/hooks";
@@ -32,6 +34,7 @@ export const MediaTranscription = (props: {
onSelectSegment?: (index: number) => void;
}) => {
const { addDblistener, removeDbListener } = useContext(DbProviderContext);
const { whisperConfig } = useContext(AISettingsProviderContext);
const { EnjoyApp } = useContext(AppSettingsProviderContext);
const {
transcription,
@@ -55,13 +58,19 @@ export const MediaTranscription = (props: {
setTranscribing(true);
setProgress(0);
transcribe({
mediaId,
mediaType,
mediaSrc: mediaUrl,
}).finally(() => {
setTranscribing(false);
});
try {
const { engine, model, result } = await transcribe(mediaUrl);
await EnjoyApp.transcriptions.update(transcription.id, {
state: "finished",
result,
engine,
model,
});
} catch (err) {
toast.error(err.message);
}
setTranscribing(false);
};
const fetchSegmentStats = async () => {
@@ -80,14 +89,16 @@ export const MediaTranscription = (props: {
generate();
}
EnjoyApp.transcriptions.onProgress((_, p: number) => {
if (p > 100) p = 100;
setProgress(p);
});
if (whisperConfig.service === "local") {
EnjoyApp.whisper.onProgress((_, p: number) => {
if (p > 100) p = 100;
setProgress(p);
});
}
return () => {
removeDbListener(fetchSegmentStats);
EnjoyApp.transcriptions.removeProgressListeners();
EnjoyApp.whisper.removeProgressListeners();
};
}, [mediaId, mediaType]);
@@ -114,7 +125,9 @@ export const MediaTranscription = (props: {
{transcribing || transcription.state === "processing" ? (
<>
<PingPoint colorClassName="bg-yellow-500" />
<div className="text-sm">{progress}%</div>
<div className="text-sm">
{whisperConfig.service === "local" && `${progress}%`}
</div>
</>
) : transcription.state === "finished" ? (
<CheckCircleIcon className="text-green-500 w-4 h-4" />

View File

@@ -77,6 +77,8 @@ export const WhisperSettings = () => {
t("azureSpeechToTextDescription")}
{whisperConfig?.service === "cloudflare" &&
t("cloudflareSpeechToTextDescription")}
{whisperConfig?.service === "openai" &&
t("openaiSpeechToTextDescription")}
</div>
</div>
@@ -94,6 +96,7 @@ export const WhisperSettings = () => {
<SelectItem value="local">{t("local")}</SelectItem>
<SelectItem value="azure">{t("azureAi")}</SelectItem>
<SelectItem value="cloudflare">{t("cloudflareAi")}</SelectItem>
<SelectItem value="openai">OpenAI</SelectItem>
</SelectContent>
</Select>

View File

@@ -1,3 +1,3 @@
export * from './use-transcode';
export * from './use-transcribe';
export * from './use-ai-command';
export * from './use-conversation';

View File

@@ -1,58 +0,0 @@
import { AppSettingsProviderContext } from "@renderer/context";
import { useContext } from "react";
import { toast } from "@renderer/components/ui";
import { t } from "i18next";
import { fetchFile } from "@ffmpeg/util";
export const useTranscribe = () => {
const { EnjoyApp, ffmpeg } = useContext(AppSettingsProviderContext);
const transcode = async (src: string, options?: string[]) => {
if (!ffmpeg?.loaded) return;
options = options || ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le"];
try {
const uri = new URL(src);
const input = uri.pathname.split("/").pop();
const output = input.replace(/\.[^/.]+$/, ".wav");
await ffmpeg.writeFile(input, await fetchFile(src));
await ffmpeg.exec(["-i", input, ...options, output]);
const data = await ffmpeg.readFile(output);
return new Blob([data], { type: "audio/wav" });
} catch (e) {
toast.error(t("transcodeError"));
}
};
const transcribe = async (params: {
mediaSrc: string;
mediaId: string;
mediaType: "Audio" | "Video";
}) => {
const { mediaSrc, mediaId, mediaType } = params;
const data = await transcode(mediaSrc);
let blob;
if (data) {
blob = {
type: data.type.split(";")[0],
arrayBuffer: await data.arrayBuffer(),
};
}
return EnjoyApp.transcriptions.process(
{
targetId: mediaId,
targetType: mediaType,
},
{
blob,
}
);
};
return {
transcode,
transcribe,
};
};

View File

@@ -0,0 +1,263 @@
import {
AppSettingsProviderContext,
AISettingsProviderContext,
} from "@renderer/context";
import OpenAI from "openai";
import { useContext } from "react";
import { toast } from "@renderer/components/ui";
import { t } from "i18next";
import { fetchFile } from "@ffmpeg/util";
import { AI_WORKER_ENDPOINT } from "@/constants";
import * as sdk from "microsoft-cognitiveservices-speech-sdk";
import axios from "axios";
import take from "lodash/take";
import sortedUniqBy from "lodash/sortedUniqBy";
import { groupTranscription, END_OF_WORD_REGEX, milisecondsToTimestamp } from "@/utils";
export const useTranscribe = () => {
const { EnjoyApp, ffmpeg, user, webApi } = useContext(
AppSettingsProviderContext
);
const { whisperConfig, openai } = useContext(AISettingsProviderContext);
const transcode = async (src: string, options?: string[]) => {
if (!ffmpeg?.loaded) return;
options = options || ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le"];
try {
const uri = new URL(src);
const input = uri.pathname.split("/").pop();
const output = input.replace(/\.[^/.]+$/, ".wav");
await ffmpeg.writeFile(input, await fetchFile(src));
await ffmpeg.exec(["-i", input, ...options, output]);
const data = await ffmpeg.readFile(output);
return new Blob([data], { type: "audio/wav" });
} catch (e) {
toast.error(t("transcodeError"));
}
};
const transcribe = async (
mediaSrc: string
): Promise<{
engine: string;
model: string;
result: TranscriptionResultSegmentGroupType[];
}> => {
const blob = await transcode(mediaSrc);
if (whisperConfig.service === "local") {
return transcribeByLocal(blob);
} else if (whisperConfig.service === "cloudflare") {
return transcribeByCloudflareAi(blob);
} else if (whisperConfig.service === "openai") {
return transcribeByOpenAi(blob);
} else if (whisperConfig.service === "azure") {
return transcribeByAzureAi(blob);
} else {
throw new Error(t("whisperServiceNotSupported"));
}
};
const transcribeByLocal = async (blob: Blob) => {
const res = await EnjoyApp.whisper.transcribe(
{
blob: {
type: blob.type.split(";")[0],
arrayBuffer: await blob.arrayBuffer(),
},
},
{
force: true,
extra: ["--prompt", `"Hello! Welcome to listen to this audio."`],
}
);
const result = groupTranscription(res.transcription);
return {
engine: "whisper",
model: res.model.type,
result,
};
};
const transcribeByOpenAi = async (blob: Blob) => {
if (!openai?.key) {
throw new Error(t("openaiKeyRequired"));
}
const client = new OpenAI({
apiKey: openai.key,
baseURL: openai.baseUrl,
dangerouslyAllowBrowser: true,
});
const res: {
words: {
word: string;
start: number;
end: number;
}[];
} = (await client.audio.transcriptions.create({
file: new File([blob], "audio.wav"),
model: "whisper-1",
response_format: "verbose_json",
timestamp_granularities: ["word"],
})) as any;
const transcription: TranscriptionResultSegmentType[] = res.words.map(
(word) => {
return {
offsets: {
from: word.start * 1000,
to: word.end * 1000,
},
timestamps: {
from: milisecondsToTimestamp(word.start * 1000),
to: milisecondsToTimestamp(word.end * 1000),
},
text: word.word,
};
}
);
const result = groupTranscription(transcription);
return {
engine: "openai",
model: "whisper-1",
result,
};
};
const transcribeByCloudflareAi = async (blob: Blob) => {
const res: CfWhipserOutputType = (
await axios.postForm(`${AI_WORKER_ENDPOINT}/audio/transcriptions`, blob, {
headers: {
Authorization: `Bearer ${user.accessToken}`,
},
timeout: 1000 * 60 * 5,
})
).data;
const transcription: TranscriptionResultSegmentType[] = res.words.map(
(word) => {
return {
offsets: {
from: word.start * 1000,
to: word.end * 1000,
},
timestamps: {
from: milisecondsToTimestamp(word.start * 1000),
to: milisecondsToTimestamp(word.end * 1000),
},
text: word.word,
};
}
);
const result = groupTranscription(transcription);
return {
engine: "cloudflare",
model: "@cf/openai/whisper",
result,
};
};
const transcribeByAzureAi = async (
blob: Blob
): Promise<{
engine: string;
model: string;
result: TranscriptionResultSegmentGroupType[];
}> => {
const { token, region } = await webApi.generateSpeechToken();
const config = sdk.SpeechConfig.fromAuthorizationToken(token, region);
const audioConfig = sdk.AudioConfig.fromWavFileInput(
new File([blob], "audio.wav")
);
// setting the recognition language to English.
config.speechRecognitionLanguage = "en-US";
config.requestWordLevelTimestamps();
config.outputFormat = sdk.OutputFormat.Detailed;
// create the speech recognizer.
const reco = new sdk.SpeechRecognizer(config, audioConfig);
let results: SpeechRecognitionResultType[] = [];
return new Promise((resolve, reject) => {
reco.recognizing = (_s, e) => {
console.log(e.result.text);
};
reco.recognized = (_s, e) => {
const json = e.result.properties.getProperty(
sdk.PropertyId.SpeechServiceResponse_JsonResult
);
const result = JSON.parse(json);
results = results.concat(result);
};
reco.canceled = (_s, e) => {
if (e.reason === sdk.CancellationReason.Error) {
return reject(new Error(e.errorDetails));
}
reco.stopContinuousRecognitionAsync();
};
reco.sessionStopped = (_s, _e) => {
reco.stopContinuousRecognitionAsync();
const transcription: TranscriptionResultSegmentType[] = [];
results.forEach((result) => {
const best = take(sortedUniqBy(result.NBest, "Confidence"), 1)[0];
const words = best.Display.trim().split(" ");
best.Words.map((word, index) => {
let text = word.Word;
if (words.length === best.Words.length) {
text = words[index];
}
if (
index === best.Words.length - 1 &&
!text.trim().match(END_OF_WORD_REGEX)
) {
text = text + ".";
}
transcription.push({
offsets: {
from: word.Offset / 1e4,
to: (word.Offset + word.Duration) / 1e4,
},
timestamps: {
from: milisecondsToTimestamp(word.Offset / 1e4),
to: milisecondsToTimestamp((word.Offset + word.Duration) * 1e4),
},
text,
});
});
});
resolve({
engine: "azure",
model: "whisper",
result: groupTranscription(transcription),
});
};
reco.startContinuousRecognitionAsync();
});
};
return {
transcode,
transcribe,
};
};

View File

@@ -213,10 +213,18 @@ type EnjoyAppType = {
setService: (
service: WhisperConfigType["service"]
) => Promise<WhisperConfigType>;
transcribeBlob: (
blob: { type: string; arrayBuffer: ArrayBuffer },
prompt?: string
) => Promise<{ file: string; content: string }>;
transcribe: (
params: {
file?: string;
blob?: { type: string; arrayBuffer: ArrayBuffer };
},
options?: {
force?: boolean;
extra?: string[];
}
) => Promise<Partial<WhisperOutputType>>;
onProgress: (callback: (event, progress: number) => void) => void;
removeProgressListeners: () => Promise<void>;
};
ffmpeg: {
config: () => Promise<FfmpegConfigType>;
@@ -245,10 +253,7 @@ type EnjoyAppType = {
};
transcriptions: {
findOrCreate: (params: any) => Promise<TranscriptionType>;
process: (params: any, options: any) => Promise<void>;
update: (id: string, params: any) => Promise<void>;
onProgress: (callback: (event, progress: number) => void) => void;
removeProgressListeners: () => Promise<void>;
};
waveforms: {
find: (id: string) => Promise<WaveFormDataType>;

View File

@@ -27,7 +27,7 @@ type NotificationType = {
};
type WhisperConfigType = {
service: "local" | "azure" | "cloudflare";
service: "local" | "azure" | "cloudflare" | "openai";
availableModels: {
type: string;
name: string;

View File

@@ -1,43 +1,5 @@
import { createHash } from "crypto";
import { createReadStream } from "fs";
import Pitchfinder from "pitchfinder";
export function hashFile(
path: string,
options: { algo: string }
): Promise<string> {
const algo = options.algo || "md5";
return new Promise((resolve, reject) => {
const hash = createHash(algo);
const stream = createReadStream(path);
stream.on("error", reject);
stream.on("data", (chunk) => hash.update(chunk));
stream.on("end", () => resolve(hash.digest("hex")));
});
}
export function hashBlob(
blob: Blob,
options: { algo: string }
): Promise<string> {
const algo = options.algo || "md5";
return new Promise((resolve, reject) => {
const hash = createHash(algo);
const reader = new FileReader();
reader.onload = () => {
if (reader.result instanceof ArrayBuffer) {
const buffer = Buffer.from(reader.result);
hash.update(buffer);
resolve(hash.digest("hex"));
} else {
reject(new Error("Unexpected result from FileReader"));
}
};
reader.onerror = reject;
reader.readAsArrayBuffer(blob);
});
}
export function generatePitch(peaks: Float32Array, sampleRate: number) {
const detectPitch = Pitchfinder.YIN({ sampleRate });
const duration = peaks.length / sampleRate;
@@ -77,3 +39,56 @@ export function milisecondsToTimestamp(ms: number) {
"0"
)}:${seconds.padStart(2, "0")},${milliseconds}`;
}
export const MAGIC_TOKENS = ["Mrs.", "Ms.", "Mr.", "Dr.", "Prof.", "St."];
export const END_OF_WORD_REGEX = /[^\.!,\?][\.!\?]/g;
export const groupTranscription = (
transcription: TranscriptionResultSegmentType[]
): TranscriptionResultSegmentGroupType[] => {
const generateGroup = (group?: TranscriptionResultSegmentType[]) => {
if (!group || group.length === 0) return;
const firstWord = group[0];
const lastWord = group[group.length - 1];
return {
offsets: {
from: firstWord.offsets.from,
to: lastWord.offsets.to,
},
text: group.map((w) => w.text.trim()).join(" "),
timestamps: {
from: firstWord.timestamps.from,
to: lastWord.timestamps.to,
},
segments: group,
};
};
const groups: TranscriptionResultSegmentGroupType[] = [];
let group: TranscriptionResultSegmentType[] = [];
transcription.forEach((segment) => {
const text = segment.text.trim();
if (!text) return;
group.push(segment);
if (
!MAGIC_TOKENS.includes(text) &&
segment.text.trim().match(END_OF_WORD_REGEX)
) {
// Group a complete sentence;
groups.push(generateGroup(group));
// init a new group
group = [];
}
});
// Group the last group
const lastSentence = generateGroup(group);
if (lastSentence) groups.push(lastSentence);
return groups;
};
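
A hedged example of the grouping behaviour with two made-up word segments (offsets and timestamps are illustrative); "world." matches END_OF_WORD_REGEX, so both words fold into a single sentence group:

import { groupTranscription } from "@/utils";

const segments: TranscriptionResultSegmentType[] = [
  {
    offsets: { from: 0, to: 400 },
    timestamps: { from: "00:00:00,000", to: "00:00:00,400" },
    text: "Hello",
  },
  {
    offsets: { from: 400, to: 900 },
    timestamps: { from: "00:00:00,400", to: "00:00:00,900" },
    text: "world.",
  },
];

const groups = groupTranscription(segments);
// groups[0].text === "Hello world."
// groups[0].offsets === { from: 0, to: 900 }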