From a95c247c8c22eff46c9c6b7f49a35584f3741b08 Mon Sep 17 00:00:00 2001 From: an-lee Date: Fri, 9 Feb 2024 17:07:21 +0800 Subject: [PATCH] Refactor whisper config (#287) * may use custom whisper exe * refactor whisper config code * display progress when using local whisper transcribe --- .../db/handlers/transcriptions-handler.ts | 8 +- enjoy/src/main/db/models/transcription.ts | 10 +- enjoy/src/main/whisper.ts | 108 ++++++++++++------ enjoy/src/preload.ts | 6 + .../components/medias/media-transcription.tsx | 14 ++- .../renderer/context/ai-settings-provider.tsx | 3 +- enjoy/src/types/enjoy-app.d.ts | 2 + 7 files changed, 110 insertions(+), 41 deletions(-) diff --git a/enjoy/src/main/db/handlers/transcriptions-handler.ts b/enjoy/src/main/db/handlers/transcriptions-handler.ts index 2621a3f1..b4eeda94 100644 --- a/enjoy/src/main/db/handlers/transcriptions-handler.ts +++ b/enjoy/src/main/db/handlers/transcriptions-handler.ts @@ -94,7 +94,13 @@ class TranscriptionsHandler { }, 1000 * 10); transcription - .process({ force, wavFileBlob: blob }) + .process({ + force, + wavFileBlob: blob, + onProgress: (progress: number) => { + event.sender.send("transcription-on-progress", progress); + }, + }) .catch((err) => { event.sender.send("on-notification", { type: "error", diff --git a/enjoy/src/main/db/models/transcription.ts b/enjoy/src/main/db/models/transcription.ts index 4124bce8..b7d7f213 100644 --- a/enjoy/src/main/db/models/transcription.ts +++ b/enjoy/src/main/db/models/transcription.ts @@ -92,11 +92,12 @@ export class Transcription extends Model { options: { force?: boolean; wavFileBlob?: { type: string; arrayBuffer: ArrayBuffer }; + onProgress?: (progress: number) => void; } = {} ) { if (this.getDataValue("state") === "processing") return; - const { force = false, wavFileBlob } = options; + const { force = false, wavFileBlob, onProgress } = options; logger.info(`[${this.getDataValue("id")}]`, "Start to transcribe."); @@ -156,9 +157,12 @@ export class Transcription extends Model { force, extra: [ "--split-on-word", - "--max-len 1", - `--prompt "Hello! Welcome to listen to this audio."`, + "--max-len", + "1", + "--prompt", + `"Hello! Welcome to listen to this audio."`, ], + onProgress, }); const result = whisper.groupTranscription(transcription); this.update({ diff --git a/enjoy/src/main/whisper.ts b/enjoy/src/main/whisper.ts index e54fd5a1..0d422986 100644 --- a/enjoy/src/main/whisper.ts +++ b/enjoy/src/main/whisper.ts @@ -5,8 +5,9 @@ import { WHISPER_MODELS_OPTIONS, PROCESS_TIMEOUT, AI_WORKER_ENDPOINT, - WEB_API_URL } from "@/constants"; -import { exec } from "child_process"; + WEB_API_URL, +} from "@/constants"; +import { exec, spawn } from "child_process"; import fs from "fs-extra"; import log from "electron-log/main"; import { t } from "i18next"; @@ -22,11 +23,21 @@ const logger = log.scope("whisper"); const MAGIC_TOKENS = ["Mrs.", "Ms.", "Mr.", "Dr.", "Prof.", "St."]; const END_OF_WORD_REGEX = /[^\.!,\?][\.!\?]/g; class Whipser { - private binMain = path.join(__dirname, "lib", "whisper", "main"); + private binMain: string; public config: WhisperConfigType; constructor(config?: WhisperConfigType) { this.config = config || settings.whisperConfig(); + const customWhisperPath = path.join( + settings.libraryPath(), + "whisper", + "main" + ); + if (fs.existsSync(customWhisperPath)) { + this.binMain = customWhisperPath; + } else { + this.binMain = path.join(__dirname, "lib", "whisper", "main"); + } } currentModel() { @@ -60,9 +71,11 @@ class Whipser { settings.setSync("whisper.modelsPath", dir); this.config = settings.whisperConfig(); + const command = `"${this.binMain}" --help`; + logger.debug(`Checking whisper command: ${command}`); return new Promise((resolve, reject) => { exec( - `"${this.binMain}" --help`, + command, { timeout: PROCESS_TIMEOUT, }, @@ -111,7 +124,7 @@ class Whipser { "--output-json", `--output-file "${path.join(tmpDir, "jfk")}"`, ]; - logger.debug(`Running command: ${commands.join(" ")}`); + logger.debug(`Checking whisper command: ${commands.join(" ")}`); exec( commands.join(" "), { @@ -180,6 +193,7 @@ class Whipser { options?: { force?: boolean; extra?: string[]; + onProgress?: (progress: number) => void; } ): Promise> { if (this.config.service === "local") { @@ -297,10 +311,11 @@ class Whipser { options?: { force?: boolean; extra?: string[]; + onProgress?: (progress: number) => void; } ): Promise> { logger.debug("transcribing from local"); - const { force = false, extra = [] } = options || {}; + const { force = false, extra = [], onProgress } = options || {}; const filename = path.basename(file, path.extname(file)); const tmpDir = settings.cachePath(); const outputFile = path.join(tmpDir, filename + ".json"); @@ -321,39 +336,60 @@ class Whipser { `--model "${this.currentModel()}"`, "--output-json", `--output-file "${path.join(tmpDir, filename)}"`, + "-pp", ...extra, ].join(" "); logger.info(`Running command: ${command}`); + + const transcribe = spawn( + this.binMain, + [ + "--file", + file, + "--model", + this.currentModel(), + "--output-json", + "--output-file", + path.join(tmpDir, filename), + "-pp", + ...extra, + ], + { + timeout: PROCESS_TIMEOUT, + } + ); + return new Promise((resolve, reject) => { - exec( - command, - { - timeout: PROCESS_TIMEOUT, - }, - (error, stdout, stderr) => { - if (fs.pathExistsSync(outputFile)) { - resolve(fs.readJson(outputFile)); - } + transcribe.stdout.on("data", (data) => { + logger.debug(`stdout: ${data}`); + }); - if (error) { - logger.error("error", error); - } - - if (stderr) { - logger.error("stderr", stderr); - } - - if (stdout) { - logger.debug(stdout); - } - - reject( - error || - new Error(stderr || "Whisper transcribe failed: unknown error") - ); + transcribe.stderr.on("data", (data) => { + const output = data.toString(); + logger.error(`stderr: ${output}`); + if (output.startsWith("whisper_print_progress_callback")) { + const progress = parseInt(output.match(/\d+%/)?.[0] || "0"); + if (typeof progress === "number") onProgress(progress); } - ); + }); + + transcribe.on("exit", (code) => { + logger.info(`transcribe process exited with code ${code}`); + }); + + transcribe.on("error", (err) => { + logger.error("transcribe error", err.message); + reject(err); + }); + + transcribe.on("close", () => { + if (fs.pathExistsSync(outputFile)) { + resolve(fs.readJson(outputFile)); + } else { + reject(new Error("Transcription failed")); + } + }); }); } @@ -424,8 +460,12 @@ class Whipser { this.config = settings.whisperConfig(); return this.check() - .then(() => { - return Object.assign({}, this.config, { ready: true }); + .then(({ success, log }) => { + if (success) { + return Object.assign({}, this.config, { ready: true }); + } else { + throw new Error(log); + } }) .catch((err) => { settings.setSync("whisper.model", originalModel); diff --git a/enjoy/src/preload.ts b/enjoy/src/preload.ts index f0fa0cd6..0a46ea17 100644 --- a/enjoy/src/preload.ts +++ b/enjoy/src/preload.ts @@ -428,6 +428,12 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", { update: (id: string, params: any) => { return ipcRenderer.invoke("transcriptions-update", id, params); }, + onProgress: ( + callback: (event: IpcRendererEvent, progress: number) => void + ) => ipcRenderer.on("transcription-on-progress", callback), + removeProgressListeners: () => { + ipcRenderer.removeAllListeners("transcription-on-progress"); + }, }, waveforms: { find: (id: string) => { diff --git a/enjoy/src/renderer/components/medias/media-transcription.tsx b/enjoy/src/renderer/components/medias/media-transcription.tsx index e6eddf56..e2c5889c 100644 --- a/enjoy/src/renderer/components/medias/media-transcription.tsx +++ b/enjoy/src/renderer/components/medias/media-transcription.tsx @@ -12,7 +12,6 @@ import { ScrollArea, Button, PingPoint, - toast, } from "@renderer/components/ui"; import React, { useEffect, useContext, useState } from "react"; import { t } from "i18next"; @@ -46,6 +45,7 @@ export const MediaTranscription = (props: { const containerRef = React.createRef(); const [transcribing, setTranscribing] = useState(false); const { transcribe } = useTranscribe(); + const [progress, setProgress] = useState(0); const [recordingStats, setRecordingStats] = useState([]); @@ -54,6 +54,7 @@ export const MediaTranscription = (props: { if (transcribing) return; setTranscribing(true); + setProgress(0); transcribe({ mediaId, mediaType, @@ -79,8 +80,14 @@ export const MediaTranscription = (props: { generate(); } + EnjoyApp.transcriptions.onProgress((_, p: number) => { + if (p > 100) p = 100; + setProgress(p); + }); + return () => { removeDbListener(fetchSegmentStats); + EnjoyApp.transcriptions.removeProgressListeners(); }; }, [mediaId, mediaType]); @@ -105,7 +112,10 @@ export const MediaTranscription = (props: {
{transcribing || transcription.state === "processing" ? ( - + <> + +
{progress}%
+ ) : transcription.state === "finished" ? ( ) : ( diff --git a/enjoy/src/renderer/context/ai-settings-provider.tsx b/enjoy/src/renderer/context/ai-settings-provider.tsx index 40f19af9..8656dbc1 100644 --- a/enjoy/src/renderer/context/ai-settings-provider.tsx +++ b/enjoy/src/renderer/context/ai-settings-provider.tsx @@ -36,10 +36,11 @@ export const AISettingsProvider = ({ useEffect(() => { fetchSettings(); - refreshWhisperConfig(); }, []); useEffect(() => { + if (!libraryPath) return; + refreshWhisperConfig(); }, [libraryPath]); diff --git a/enjoy/src/types/enjoy-app.d.ts b/enjoy/src/types/enjoy-app.d.ts index af4b9af8..851d89a5 100644 --- a/enjoy/src/types/enjoy-app.d.ts +++ b/enjoy/src/types/enjoy-app.d.ts @@ -246,6 +246,8 @@ type EnjoyAppType = { findOrCreate: (params: any) => Promise; process: (params: any, options: any) => Promise; update: (id: string, params: any) => Promise; + onProgress: (callback: (event, progress: number) => void) => void; + removeProgressListeners: () => Promise; }; waveforms: { find: (id: string) => Promise;