From 071d80060deaee45aca9bbd896f9bc54d748e2cc Mon Sep 17 00:00:00 2001 From: an-lee Date: Tue, 2 Apr 2024 14:43:15 +0800 Subject: [PATCH] Fix punctuation (#477) * add punctuate command * check punctuation before alignment --- enjoy/src/commands/index.ts | 1 + enjoy/src/commands/punctuate.command.ts | 38 +++++++++++++++++++++ enjoy/src/renderer/hooks/use-ai-command.tsx | 10 ++++++ enjoy/src/renderer/hooks/use-transcribe.tsx | 14 +++++++- 4 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 enjoy/src/commands/punctuate.command.ts diff --git a/enjoy/src/commands/index.ts b/enjoy/src/commands/index.ts index 2e7f7e5c..25f8bd05 100644 --- a/enjoy/src/commands/index.ts +++ b/enjoy/src/commands/index.ts @@ -3,3 +3,4 @@ export * from "./lookup.command"; export * from "./translate.command"; export * from "./ipa.command"; export * from "./analyze.command"; +export * from "./punctuate.command"; diff --git a/enjoy/src/commands/punctuate.command.ts b/enjoy/src/commands/punctuate.command.ts new file mode 100644 index 00000000..b6847807 --- /dev/null +++ b/enjoy/src/commands/punctuate.command.ts @@ -0,0 +1,38 @@ +import { ChatOpenAI } from "@langchain/openai"; +import { ChatPromptTemplate } from "@langchain/core/prompts"; + +export const punctuateCommand = async ( + text: string, + options: { + key: string; + modelName?: string; + temperature?: number; + baseUrl?: string; + } +): Promise => { + const { key, temperature = 0, baseUrl } = options; + let { modelName = "gpt-4-turbo-preview" } = options; + + const chatModel = new ChatOpenAI({ + openAIApiKey: key, + modelName, + temperature, + configuration: { + baseURL: baseUrl, + }, + cache: false, + verbose: true, + maxRetries: 2, + }); + + const prompt = ChatPromptTemplate.fromMessages([ + ["system", SYSTEM_PROMPT], + ["human", text], + ]); + + const response = await prompt.pipe(chatModel).invoke({}); + + return response.text; +}; + +const SYSTEM_PROMPT = `Please add proper punctuation to the text I provide you. Return the corrected text only.`; diff --git a/enjoy/src/renderer/hooks/use-ai-command.tsx b/enjoy/src/renderer/hooks/use-ai-command.tsx index dc081baf..c337655e 100644 --- a/enjoy/src/renderer/hooks/use-ai-command.tsx +++ b/enjoy/src/renderer/hooks/use-ai-command.tsx @@ -8,6 +8,7 @@ import { extractStoryCommand, translateCommand, analyzeCommand, + punctuateCommand, } from "@commands"; export const useAiCommand = () => { @@ -102,10 +103,19 @@ export const useAiCommand = () => { }); }; + const punctuateText = async (text: string) => { + return punctuateCommand(text, { + key: currentEngine.key, + modelName: currentEngine.model, + baseUrl: currentEngine.baseUrl, + }) + } + return { lookupWord, extractStory, translate, analyzeText, + punctuateText }; }; diff --git a/enjoy/src/renderer/hooks/use-transcribe.tsx b/enjoy/src/renderer/hooks/use-transcribe.tsx index 669f8466..160da334 100644 --- a/enjoy/src/renderer/hooks/use-transcribe.tsx +++ b/enjoy/src/renderer/hooks/use-transcribe.tsx @@ -9,10 +9,12 @@ import { AI_WORKER_ENDPOINT } from "@/constants"; import * as sdk from "microsoft-cognitiveservices-speech-sdk"; import axios from "axios"; import { AlignmentResult } from "echogarden/dist/api/API.d.js"; +import { useAiCommand } from "./use-ai-command"; export const useTranscribe = () => { const { EnjoyApp, user, webApi } = useContext(AppSettingsProviderContext); const { whisperConfig, openai } = useContext(AISettingsProviderContext); + const { punctuateText } = useAiCommand(); const transcode = async (src: string | Blob): Promise => { if (src instanceof Blob) { @@ -61,9 +63,19 @@ export const useTranscribe = () => { throw new Error(t("whisperServiceNotSupported")); } + let transcript = originalText || result.text; + // if the transcript does not contain any punctuation, use AI command to add punctuation + if (!transcript.match(/\w[.,!?](\s|$)/)) { + try { + transcript = await punctuateText(transcript); + } catch (err) { + console.warn(err.message); + } + } + const alignmentResult = await EnjoyApp.echogarden.align( new Uint8Array(await blob.arrayBuffer()), - originalText || result.text + transcript ); return {