Refactor whisper config (#287)

* allow using a custom whisper executable

* refactor whisper config code

* display progress when transcribing with local whisper
This commit is contained in:
an-lee
2024-02-09 17:07:21 +08:00
committed by GitHub
parent 5eafd45ac5
commit a95c247c8c
7 changed files with 110 additions and 41 deletions

View File

@@ -94,7 +94,13 @@ class TranscriptionsHandler {
}, 1000 * 10);
transcription
.process({ force, wavFileBlob: blob })
.process({
force,
wavFileBlob: blob,
onProgress: (progress: number) => {
event.sender.send("transcription-on-progress", progress);
},
})
.catch((err) => {
event.sender.send("on-notification", {
type: "error",

View File

@@ -92,11 +92,12 @@ export class Transcription extends Model<Transcription> {
options: {
force?: boolean;
wavFileBlob?: { type: string; arrayBuffer: ArrayBuffer };
onProgress?: (progress: number) => void;
} = {}
) {
if (this.getDataValue("state") === "processing") return;
const { force = false, wavFileBlob } = options;
const { force = false, wavFileBlob, onProgress } = options;
logger.info(`[${this.getDataValue("id")}]`, "Start to transcribe.");
@@ -156,9 +157,12 @@ export class Transcription extends Model<Transcription> {
force,
extra: [
"--split-on-word",
"--max-len 1",
`--prompt "Hello! Welcome to listen to this audio."`,
"--max-len",
"1",
"--prompt",
`"Hello! Welcome to listen to this audio."`,
],
onProgress,
});
const result = whisper.groupTranscription(transcription);
this.update({

View File

@@ -5,8 +5,9 @@ import {
WHISPER_MODELS_OPTIONS,
PROCESS_TIMEOUT,
AI_WORKER_ENDPOINT,
WEB_API_URL } from "@/constants";
import { exec } from "child_process";
WEB_API_URL,
} from "@/constants";
import { exec, spawn } from "child_process";
import fs from "fs-extra";
import log from "electron-log/main";
import { t } from "i18next";
@@ -22,11 +23,21 @@ const logger = log.scope("whisper");
const MAGIC_TOKENS = ["Mrs.", "Ms.", "Mr.", "Dr.", "Prof.", "St."];
const END_OF_WORD_REGEX = /[^\.!,\?][\.!\?]/g;
class Whipser {
private binMain = path.join(__dirname, "lib", "whisper", "main");
private binMain: string;
public config: WhisperConfigType;
/**
 * Creates the whisper wrapper.
 *
 * Uses the given config, or the one stored in settings when none is passed.
 * The executable is taken from `<libraryPath>/whisper/main` when that file
 * exists; otherwise the copy under `<__dirname>/lib/whisper/main` is used.
 */
constructor(config?: WhisperConfigType) {
  this.config = config || settings.whisperConfig();

  const libraryBinary = path.join(settings.libraryPath(), "whisper", "main");
  const bundledBinary = path.join(__dirname, "lib", "whisper", "main");

  // A binary placed in the library directory takes precedence over the bundled one.
  this.binMain = fs.existsSync(libraryBinary) ? libraryBinary : bundledBinary;
}
currentModel() {
@@ -60,9 +71,11 @@ class Whipser {
settings.setSync("whisper.modelsPath", dir);
this.config = settings.whisperConfig();
const command = `"${this.binMain}" --help`;
logger.debug(`Checking whisper command: ${command}`);
return new Promise((resolve, reject) => {
exec(
`"${this.binMain}" --help`,
command,
{
timeout: PROCESS_TIMEOUT,
},
@@ -111,7 +124,7 @@ class Whipser {
"--output-json",
`--output-file "${path.join(tmpDir, "jfk")}"`,
];
logger.debug(`Running command: ${commands.join(" ")}`);
logger.debug(`Checking whisper command: ${commands.join(" ")}`);
exec(
commands.join(" "),
{
@@ -180,6 +193,7 @@ class Whipser {
options?: {
force?: boolean;
extra?: string[];
onProgress?: (progress: number) => void;
}
): Promise<Partial<WhisperOutputType>> {
if (this.config.service === "local") {
@@ -297,10 +311,11 @@ class Whipser {
options?: {
force?: boolean;
extra?: string[];
onProgress?: (progress: number) => void;
}
): Promise<Partial<WhisperOutputType>> {
logger.debug("transcribing from local");
const { force = false, extra = [] } = options || {};
const { force = false, extra = [], onProgress } = options || {};
const filename = path.basename(file, path.extname(file));
const tmpDir = settings.cachePath();
const outputFile = path.join(tmpDir, filename + ".json");
@@ -321,39 +336,60 @@ class Whipser {
`--model "${this.currentModel()}"`,
"--output-json",
`--output-file "${path.join(tmpDir, filename)}"`,
"-pp",
...extra,
].join(" ");
logger.info(`Running command: ${command}`);
const transcribe = spawn(
this.binMain,
[
"--file",
file,
"--model",
this.currentModel(),
"--output-json",
"--output-file",
path.join(tmpDir, filename),
"-pp",
...extra,
],
{
timeout: PROCESS_TIMEOUT,
}
);
return new Promise((resolve, reject) => {
exec(
command,
{
timeout: PROCESS_TIMEOUT,
},
(error, stdout, stderr) => {
if (fs.pathExistsSync(outputFile)) {
resolve(fs.readJson(outputFile));
}
transcribe.stdout.on("data", (data) => {
logger.debug(`stdout: ${data}`);
});
if (error) {
logger.error("error", error);
}
if (stderr) {
logger.error("stderr", stderr);
}
if (stdout) {
logger.debug(stdout);
}
reject(
error ||
new Error(stderr || "Whisper transcribe failed: unknown error")
);
transcribe.stderr.on("data", (data) => {
const output = data.toString();
logger.error(`stderr: ${output}`);
// Report progress only when a listener was supplied: `onProgress` is optional,
// and calling it unconditionally would throw inside this stream handler.
if (onProgress) {
  // One stderr chunk may hold several lines, and the progress line is not
  // guaranteed to be the first one, so scan the whole chunk and keep the
  // last percentage that appears on a progress-callback line.
  const progressLines = output.match(
    /whisper_print_progress_callback[^\n]*?\d+%/g
  );
  if (progressLines) {
    const latest = progressLines[progressLines.length - 1];
    const progress = parseInt(latest.match(/\d+(?=%)/)?.[0] || "", 10);
    // parseInt yields NaN (still of type "number") on bad input; skip it.
    if (!Number.isNaN(progress)) onProgress(progress);
  }
}
);
});
transcribe.on("exit", (code) => {
logger.info(`transcribe process exited with code ${code}`);
});
transcribe.on("error", (err) => {
logger.error("transcribe error", err.message);
reject(err);
});
transcribe.on("close", () => {
if (fs.pathExistsSync(outputFile)) {
resolve(fs.readJson(outputFile));
} else {
reject(new Error("Transcription failed"));
}
});
});
}
@@ -424,8 +460,12 @@ class Whipser {
this.config = settings.whisperConfig();
return this.check()
.then(() => {
return Object.assign({}, this.config, { ready: true });
.then(({ success, log }) => {
if (success) {
return Object.assign({}, this.config, { ready: true });
} else {
throw new Error(log);
}
})
.catch((err) => {
settings.setSync("whisper.model", originalModel);

View File

@@ -428,6 +428,12 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
// Sends an update request for the transcription `id` to the main process
// on the "transcriptions-update" channel and returns the pending result.
update: (id: string, params: any) =>
  ipcRenderer.invoke("transcriptions-update", id, params),
// Subscribes `callback` to "transcription-on-progress" messages.
// Uses a block body so nothing is returned: `ipcRenderer.on()` returns the
// `ipcRenderer` instance itself, which should not be handed back to the page.
onProgress: (
  callback: (event: IpcRendererEvent, progress: number) => void
): void => {
  ipcRenderer.on("transcription-on-progress", callback);
},
// Drops every listener registered on the "transcription-on-progress" channel.
// `void` discards the call's return value so the function yields undefined.
removeProgressListeners: () =>
  void ipcRenderer.removeAllListeners("transcription-on-progress"),
},
waveforms: {
find: (id: string) => {

View File

@@ -12,7 +12,6 @@ import {
ScrollArea,
Button,
PingPoint,
toast,
} from "@renderer/components/ui";
import React, { useEffect, useContext, useState } from "react";
import { t } from "i18next";
@@ -46,6 +45,7 @@ export const MediaTranscription = (props: {
const containerRef = React.createRef<HTMLDivElement>();
const [transcribing, setTranscribing] = useState<boolean>(false);
const { transcribe } = useTranscribe();
const [progress, setProgress] = useState<number>(0);
const [recordingStats, setRecordingStats] =
useState<SegementRecordingStatsType>([]);
@@ -54,6 +54,7 @@ export const MediaTranscription = (props: {
if (transcribing) return;
setTranscribing(true);
setProgress(0);
transcribe({
mediaId,
mediaType,
@@ -79,8 +80,14 @@ export const MediaTranscription = (props: {
generate();
}
EnjoyApp.transcriptions.onProgress((_, p: number) => {
if (p > 100) p = 100;
setProgress(p);
});
return () => {
removeDbListener(fetchSegmentStats);
EnjoyApp.transcriptions.removeProgressListeners();
};
}, [mediaId, mediaType]);
@@ -105,7 +112,10 @@ export const MediaTranscription = (props: {
<div className="mb-4 flex items-cener justify-between">
<div className="flex items-center space-x-2">
{transcribing || transcription.state === "processing" ? (
<PingPoint colorClassName="bg-yellow-500" />
<>
<PingPoint colorClassName="bg-yellow-500" />
<div className="text-sm">{progress}%</div>
</>
) : transcription.state === "finished" ? (
<CheckCircleIcon className="text-green-500 w-4 h-4" />
) : (

View File

@@ -36,10 +36,11 @@ export const AISettingsProvider = ({
useEffect(() => {
fetchSettings();
refreshWhisperConfig();
}, []);
useEffect(() => {
if (!libraryPath) return;
refreshWhisperConfig();
}, [libraryPath]);

View File

@@ -246,6 +246,8 @@ type EnjoyAppType = {
findOrCreate: (params: any) => Promise<TranscriptionType>;
process: (params: any, options: any) => Promise<void>;
update: (id: string, params: any) => Promise<void>;
onProgress: (callback: (event, progress: number) => void) => void;
// Synchronous: removes the progress listeners and returns nothing (no promise).
removeProgressListeners: () => void;
};
waveforms: {
find: (id: string) => Promise<WaveFormDataType>;