From abde169ead52e68c4e3cd419acae1f21e0bc74d9 Mon Sep 17 00:00:00 2001 From: an-lee Date: Fri, 2 Feb 2024 00:41:23 +0800 Subject: [PATCH] Fix openai proxy (#244) * add create messages in batch * add use conversation * update conversation shortcut * add speech handler * tts in renderer * fix speech create --- enjoy/src/main/db/handlers/index.ts | 1 + .../src/main/db/handlers/messages-handler.ts | 23 +- .../src/main/db/handlers/speeches-handler.ts | 50 ++++ enjoy/src/main/db/index.ts | 2 + enjoy/src/main/db/models/conversation.ts | 2 +- enjoy/src/main/db/models/speech.ts | 2 +- enjoy/src/preload.ts | 20 ++ .../conversations/conversations-shortcut.tsx | 10 +- .../components/messages/assistant-message.tsx | 15 +- enjoy/src/renderer/hooks/index.ts | 1 + enjoy/src/renderer/hooks/useConversation.tsx | 213 ++++++++++++++++++ enjoy/src/renderer/pages/conversation.tsx | 15 +- enjoy/src/types/enjoy-app.d.ts | 19 ++ enjoy/src/types/message.d.ts | 11 +- enjoy/src/types/speech.d.ts | 1 + 15 files changed, 358 insertions(+), 27 deletions(-) create mode 100644 enjoy/src/main/db/handlers/speeches-handler.ts create mode 100644 enjoy/src/renderer/hooks/useConversation.tsx diff --git a/enjoy/src/main/db/handlers/index.ts b/enjoy/src/main/db/handlers/index.ts index 0aaf02db..d3a994a2 100644 --- a/enjoy/src/main/db/handlers/index.ts +++ b/enjoy/src/main/db/handlers/index.ts @@ -3,5 +3,6 @@ export * from './recordings-handler'; export * from './messages-handler'; export * from './conversations-handler'; export * from './cache-objects-handler'; +export * from './speeches-handler'; export * from './transcriptions-handler'; export * from './videos-handler'; diff --git a/enjoy/src/main/db/handlers/messages-handler.ts b/enjoy/src/main/db/handlers/messages-handler.ts index 655723a3..8bbce4c1 100644 --- a/enjoy/src/main/db/handlers/messages-handler.ts +++ b/enjoy/src/main/db/handlers/messages-handler.ts @@ -3,6 +3,7 @@ import { Message, Speech, Conversation } from "@main/db/models"; import { FindOptions, WhereOptions, Attributes } from "sequelize"; import log from "electron-log/main"; import { t } from "i18next"; +import db from "@main/db"; class MessagesHandler { private async findAll( @@ -16,7 +17,7 @@ class MessagesHandler { model: Speech, where: { sourceType: "Message" }, required: false, - } + }, ], order: [["createdAt", "DESC"]], ...options, @@ -79,6 +80,25 @@ class MessagesHandler { }); } + private async createInBatch(event: IpcMainEvent, messages: Message[]) { + try { + const transaction = await db.connection.transaction(); + for (const message of messages) { + await Message.create(message, { + include: [Conversation], + transaction, + }); + } + + await transaction.commit(); + } catch (err) { + event.sender.send("on-notification", { + type: "error", + message: err.message, + }); + } + } + private async update( event: IpcMainEvent, id: string, @@ -150,6 +170,7 @@ class MessagesHandler { ipcMain.handle("messages-find-all", this.findAll); ipcMain.handle("messages-find-one", this.findOne); ipcMain.handle("messages-create", this.create); + ipcMain.handle("messages-create-in-batch", this.createInBatch); ipcMain.handle("messages-update", this.update); ipcMain.handle("messages-destroy", this.destroy); ipcMain.handle("messages-create-speech", this.createSpeech); diff --git a/enjoy/src/main/db/handlers/speeches-handler.ts b/enjoy/src/main/db/handlers/speeches-handler.ts new file mode 100644 index 00000000..93e5200b --- /dev/null +++ b/enjoy/src/main/db/handlers/speeches-handler.ts @@ -0,0 +1,50 @@ +import { ipcMain, 
IpcMainEvent } from "electron"; +import { Speech } from "@main/db/models"; +import fs from "fs-extra"; +import path from "path"; +import settings from "@main/settings"; +import { hashFile } from "@/utils"; + +class SpeechesHandler { + private async create( + event: IpcMainEvent, + params: { + sourceId: string; + sourceType: string; + text: string; + configuration: { + engine: string; + model: string; + voice: string; + }; + }, + blob: { + type: string; + arrayBuffer: ArrayBuffer; + } + ) { + const format = blob.type.split("/")[1]; + const filename = `${Date.now()}.${format}`; + const file = path.join(settings.userDataPath(), "speeches", filename); + await fs.outputFile(file, Buffer.from(blob.arrayBuffer)); + const md5 = await hashFile(file, { algo: "md5" }); + fs.renameSync(file, path.join(path.dirname(file), `${md5}.${format}`)); + + return Speech.create({ ...params, extname: `.${format}`, md5 }) + .then((speech) => { + return speech.toJSON(); + }) + .catch((err) => { + event.sender.send("on-notification", { + type: "error", + message: err.message, + }); + }); + } + + register() { + ipcMain.handle("speeches-create", this.create); + } +} + +export const speechesHandler = new SpeechesHandler(); diff --git a/enjoy/src/main/db/index.ts b/enjoy/src/main/db/index.ts index fe5d0c6d..b326bda9 100644 --- a/enjoy/src/main/db/index.ts +++ b/enjoy/src/main/db/index.ts @@ -19,6 +19,7 @@ import { conversationsHandler, messagesHandler, recordingsHandler, + speechesHandler, transcriptionsHandler, videosHandler, } from "./handlers"; @@ -89,6 +90,7 @@ db.connect = async () => { recordingsHandler.register(); conversationsHandler.register(); messagesHandler.register(); + speechesHandler.register(); transcriptionsHandler.register(); videosHandler.register(); }; diff --git a/enjoy/src/main/db/models/conversation.ts b/enjoy/src/main/db/models/conversation.ts index a4c46e69..b4251217 100644 --- a/enjoy/src/main/db/models/conversation.ts +++ b/enjoy/src/main/db/models/conversation.ts @@ -288,7 +288,7 @@ export class Conversation extends Model { extra: [`--prompt "${prompt}"`], }); content = transcription - .map((t: TranscriptionSegmentType) => t.text) + .map((t: TranscriptionResultSegmentType) => t.text) .join(" ") .trim(); diff --git a/enjoy/src/main/db/models/speech.ts b/enjoy/src/main/db/models/speech.ts index bedef721..b0e740a8 100644 --- a/enjoy/src/main/db/models/speech.ts +++ b/enjoy/src/main/db/models/speech.ts @@ -92,7 +92,7 @@ export class Speech extends Model { @Column(DataType.VIRTUAL) get voice(): string { - return this.getDataValue("configuration").model; + return this.getDataValue("configuration").voice; } @Column(DataType.VIRTUAL) diff --git a/enjoy/src/preload.ts b/enjoy/src/preload.ts index f0e0ee6d..f0fa0cd6 100644 --- a/enjoy/src/preload.ts +++ b/enjoy/src/preload.ts @@ -306,6 +306,9 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", { findOne: (params: object) => { return ipcRenderer.invoke("messages-find-one", params); }, + createInBatch: (messages: Partial[]) => { + return ipcRenderer.invoke("messages-create-in-batch", messages); + }, destroy: (id: string) => { return ipcRenderer.invoke("messages-destroy", id); }, @@ -313,6 +316,23 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", { return ipcRenderer.invoke("messages-create-speech", id, configuration); }, }, + speeches: { + create: ( + params: { + sourceId: string; + sourceType: string; + text: string; + configuration: { + engine: string; + model: string; + voice: string; + }; + }, + blob: { type: string; arrayBuffer: ArrayBuffer } + ) 
=> { + return ipcRenderer.invoke("speeches-create", params, blob); + }, + }, audiowaveform: { generate: ( file: string, diff --git a/enjoy/src/renderer/components/conversations/conversations-shortcut.tsx b/enjoy/src/renderer/components/conversations/conversations-shortcut.tsx index 1fe5b6c9..bcfa2dec 100644 --- a/enjoy/src/renderer/components/conversations/conversations-shortcut.tsx +++ b/enjoy/src/renderer/components/conversations/conversations-shortcut.tsx @@ -4,16 +4,18 @@ import { Button, ScrollArea, toast } from "@renderer/components/ui"; import { LoaderSpin } from "@renderer/components"; import { MessageCircleIcon, LoaderIcon } from "lucide-react"; import { t } from "i18next"; +import { useConversation } from "@renderer/hooks"; export const ConversationsShortcut = (props: { prompt: string; - onReply?: (reply: MessageType[]) => void; + onReply?: (reply: Partial[]) => void; }) => { const { EnjoyApp } = useContext(AppSettingsProviderContext); const { prompt, onReply } = props; const [conversations, setConversations] = useState([]); const [loading, setLoading] = useState(false); const [offset, setOffset] = useState(0); + const { chat } = useConversation(); const fetchConversations = () => { if (offset === -1) return; @@ -51,10 +53,8 @@ export const ConversationsShortcut = (props: { const ask = (conversation: ConversationType) => { setLoading(true); - EnjoyApp.conversations - .ask(conversation.id, { - content: prompt, - }) + + chat({ content: prompt }, { conversation }) .then((replies) => { onReply(replies); }) diff --git a/enjoy/src/renderer/components/messages/assistant-message.tsx b/enjoy/src/renderer/components/messages/assistant-message.tsx index ae157119..17d6a5b4 100644 --- a/enjoy/src/renderer/components/messages/assistant-message.tsx +++ b/enjoy/src/renderer/components/messages/assistant-message.tsx @@ -31,6 +31,7 @@ import { useCopyToClipboard } from "@uidotdev/usehooks"; import { t } from "i18next"; import { AppSettingsProviderContext } from "@renderer/context"; import Markdown from "react-markdown"; +import { useConversation } from "@renderer/hooks"; export const AssistantMessageComponent = (props: { message: MessageType; @@ -46,6 +47,7 @@ export const AssistantMessageComponent = (props: { const [resourcing, setResourcing] = useState(false); const [shadowing, setShadowing] = useState(false); const { EnjoyApp } = useContext(AppSettingsProviderContext); + const { tts } = useConversation(); useEffect(() => { if (speech) return; @@ -59,13 +61,12 @@ export const AssistantMessageComponent = (props: { setSpeeching(true); - EnjoyApp.messages - .createSpeech(message.id, { - engine: configuration?.tts?.engine, - model: configuration?.tts?.model, - voice: configuration?.tts?.voice, - baseUrl: configuration?.tts?.baseUrl, - }) + tts({ + sourceType: "Message", + sourceId: message.id, + text: message.content, + configuration: configuration.tts, + }) .then((speech) => { setSpeech(speech); }) diff --git a/enjoy/src/renderer/hooks/index.ts b/enjoy/src/renderer/hooks/index.ts index 8a33e577..8c678842 100644 --- a/enjoy/src/renderer/hooks/index.ts +++ b/enjoy/src/renderer/hooks/index.ts @@ -1,2 +1,3 @@ export * from './useTranscode'; export * from './useAiCommand'; +export * from './useConversation'; diff --git a/enjoy/src/renderer/hooks/useConversation.tsx b/enjoy/src/renderer/hooks/useConversation.tsx new file mode 100644 index 00000000..19e96803 --- /dev/null +++ b/enjoy/src/renderer/hooks/useConversation.tsx @@ -0,0 +1,213 @@ +import { + AppSettingsProviderContext, + 
AISettingsProviderContext, +} from "@renderer/context"; +import { useContext } from "react"; +import { ChatMessageHistory, BufferMemory } from "langchain/memory"; +import { ConversationChain } from "langchain/chains"; +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { ChatOllama } from "langchain/chat_models/ollama"; +import { ChatGoogleGenerativeAI } from "@langchain/google-genai"; +import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts"; +import OpenAI, { type ClientOptions } from "openai"; +import { type Generation } from "langchain/dist/schema"; +import { v4 } from "uuid"; + +export const useConversation = () => { + const { EnjoyApp, user, apiUrl } = useContext(AppSettingsProviderContext); + const { openai, googleGenerativeAi, currentEngine } = useContext( + AISettingsProviderContext + ); + + const pickLlm = (conversation: ConversationType) => { + const { + baseUrl, + model, + temperature, + maxTokens, + frequencyPenalty, + presencePenalty, + numberOfChoices, + } = conversation.configuration; + + if (conversation.engine === "enjoyai") { + return new ChatOpenAI({ + openAIApiKey: user.accessToken, + configuration: { + baseURL: `${apiUrl}/api/ai`, + }, + modelName: model, + temperature, + maxTokens, + frequencyPenalty, + presencePenalty, + n: numberOfChoices, + }); + } else if (conversation.engine === "openai") { + return new ChatOpenAI({ + openAIApiKey: openai.key, + configuration: { + baseURL: baseUrl, + }, + modelName: model, + temperature, + maxTokens, + frequencyPenalty, + presencePenalty, + n: numberOfChoices, + }); + } else if (conversation.engine === "ollama") { + return new ChatOllama({ + baseUrl, + model, + temperature, + frequencyPenalty, + presencePenalty, + }); + } else if (conversation.engine === "googleGenerativeAi") { + return new ChatGoogleGenerativeAI({ + apiKey: googleGenerativeAi.key, + modelName: model, + temperature: temperature, + maxOutputTokens: maxTokens, + }); + } + }; + + const fetchChatHistory = async (conversation: ConversationType) => { + const chatMessageHistory = new ChatMessageHistory(); + let limit = conversation.configuration.historyBufferSize; + if (!limit || limit < 0) { + limit = 0; + } + const _messages: MessageType[] = await EnjoyApp.messages.findAll({ + where: { conversationId: conversation.id }, + order: [["createdAt", "DESC"]], + limit, + }); + + _messages + .sort( + (a, b) => + new Date(a.createdAt).getUTCMilliseconds() - + new Date(b.createdAt).getUTCMilliseconds() + ) + .forEach((message) => { + if (message.role === "user") { + chatMessageHistory.addUserMessage(message.content); + } else if (message.role === "assistant") { + chatMessageHistory.addAIChatMessage(message.content); + } + }); + + return chatMessageHistory; + }; + + const chat = async ( + message: Partial, + params: { + conversation: ConversationType; + } + ): Promise[]> => { + const { conversation } = params; + const chatHistory = await fetchChatHistory(conversation); + const memory = new BufferMemory({ + chatHistory, + memoryKey: "history", + returnMessages: true, + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system" as MessageRoleEnum, conversation.configuration.roleDefinition], + new MessagesPlaceholder("history"), + ["human", "{input}"], + ]); + + const llm = pickLlm(conversation); + const chain = new ConversationChain({ + // @ts-ignore + llm, + memory, + prompt, + verbose: true, + }); + let response: Generation[] = []; + await chain.call({ input: message.content }, [ + { + handleLLMEnd: async (output) => { + response = 
output.generations[0]; + }, + }, + ]); + + const replies = response.map((r) => { + return { + id: v4(), + content: r.text, + role: "assistant" as MessageRoleEnum, + conversationId: conversation.id, + }; + }); + + message.role = "user" as MessageRoleEnum; + message.conversationId = conversation.id; + + await EnjoyApp.messages.createInBatch([message, ...replies]); + + return replies; + }; + + const tts = async (params: Partial) => { + const { configuration } = params; + const { + engine = currentEngine.name, + model = "tts-1", + voice = "alloy", + baseUrl = currentEngine.baseUrl, + } = configuration || {}; + + let client: OpenAI; + + if (engine === "enjoyai") { + client = new OpenAI({ + apiKey: user.accessToken, + baseURL: `${apiUrl}/api/ai`, + dangerouslyAllowBrowser: true, + }); + } else { + client = new OpenAI({ + apiKey: openai.key, + baseURL: baseUrl, + dangerouslyAllowBrowser: true, + }); + } + + const file = await client.audio.speech.create({ + input: params.text, + model, + voice, + }); + const buffer = await file.arrayBuffer(); + + return EnjoyApp.speeches.create( + { + text: params.text, + sourceType: params.sourceType, + sourceId: params.sourceId, + configuration: { + engine, + model, + voice, + }, + }, + { + type: "audio/mp3", + arrayBuffer: buffer, + } + ); + }; + + return { + chat, + tts, + }; +}; diff --git a/enjoy/src/renderer/pages/conversation.tsx b/enjoy/src/renderer/pages/conversation.tsx index 26838253..8ac1af9e 100644 --- a/enjoy/src/renderer/pages/conversation.tsx +++ b/enjoy/src/renderer/pages/conversation.tsx @@ -19,6 +19,7 @@ import { import { messagesReducer } from "@renderer/reducers"; import { v4 as uuidv4 } from "uuid"; import autosize from "autosize"; +import { useConversation } from "@renderer/hooks"; export default () => { const { id } = useParams<{ id: string }>(); @@ -35,6 +36,7 @@ export default () => { const [messages, dispatchMessages] = useReducer(messagesReducer, []); const [offset, setOffest] = useState(0); const [loading, setLoading] = useState(false); + const { chat } = useConversation(); const inputRef = useRef(null); const submitRef = useRef(null); @@ -91,7 +93,7 @@ export default () => { const message: MessageType = { id: uuidv4(), content: text, - role: "user", + role: "user" as MessageRoleEnum, conversationId: id, status: "pending", }; @@ -118,15 +120,8 @@ export default () => { setSubmitting(false); }, 1000 * 60 * 5); - EnjoyApp.conversations - .ask(conversation.id, { - messageId: message.id, - content: message.content, - file, - }) - .then((reply) => { - if (reply) return; - + chat(message, { conversation }) + .catch(() => { message.status = "error"; dispatchMessages({ type: "update", record: message }); }) diff --git a/enjoy/src/types/enjoy-app.d.ts b/enjoy/src/types/enjoy-app.d.ts index 367a9afe..af4b9af8 100644 --- a/enjoy/src/types/enjoy-app.d.ts +++ b/enjoy/src/types/enjoy-app.d.ts @@ -183,9 +183,28 @@ type EnjoyAppType = { messages: { findAll: (params: object) => Promise; findOne: (params: object) => Promise; + createInBatch: (messages: Partial[]) => Promise; destroy: (id: string) => Promise; createSpeech: (id: string, configuration?: any) => Promise; }; + speeches: { + create: ( + params: { + sourceId: string; + sourceType: string; + text: string; + configuration: { + engine: string; + model: string; + voice: string; + }; + }, + blob: { + type: string; + arrayBuffer: ArrayBuffer; + } + ) => Promise; + }; whisper: { config: () => Promise; check: () => Promise<{ success: boolean; log: string }>; diff --git a/enjoy/src/types/message.d.ts 
b/enjoy/src/types/message.d.ts
index 44862815..ebac8d15 100644
--- a/enjoy/src/types/message.d.ts
+++ b/enjoy/src/types/message.d.ts
@@ -1,10 +1,17 @@
+enum MessageRoleEnum {
+  SYSTEM = "system",
+  ASSISTANT = "assistant",
+  USER = "user",
+}
+
 type MessageType = {
   id: string;
-  role: "system" | "assistant" | "user";
+  role: MessageRoleEnum;
   content: string;
   conversationId: string;
   conversation?: ConversationType;
-  createdAt?: string;
+  createdAt?: Date;
+  updatedAt?: Date;
   status?: "pending" | "success" | "error";
   speeches?: Partial<SpeechType>[];
   recording?: RecordingType;
diff --git a/enjoy/src/types/speech.d.ts b/enjoy/src/types/speech.d.ts
index 09234a36..6c665c3c 100644
--- a/enjoy/src/types/speech.d.ts
+++ b/enjoy/src/types/speech.d.ts
@@ -10,6 +10,7 @@ type SpeechType = {
   md5: string;
   filename: string;
   filePath: string;
+  configuration: {[key: string]: any};
   src?: string;
   createdAt: Date;
   updatedAt: Date;
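A few notes on the main pieces this patch adds, with small usage sketches. The new `messages-create-in-batch` IPC channel persists a user message together with the assistant replies in a single Sequelize transaction. As committed, the handler opens the transaction inside the `try` block and only ever commits; if one of the `Message.create` calls throws, the error is reported but the transaction is never rolled back. A minimal sketch of the same handler body with an explicit rollback, assuming `db.connection` is the Sequelize instance exported from `@main/db` (the module the handler already imports):

```ts
import { IpcMainEvent } from "electron";
import { Message, Conversation } from "@main/db/models";
import db from "@main/db";

// Sketch: create a batch of messages atomically; roll back if any create fails.
async function createInBatch(event: IpcMainEvent, messages: Partial<Message>[]) {
  const transaction = await db.connection.transaction();
  try {
    for (const message of messages) {
      await Message.create(message as any, {
        include: [Conversation],
        transaction,
      });
    }
    await transaction.commit();
  } catch (err: any) {
    await transaction.rollback(); // undo any rows created before the failure
    event.sender.send("on-notification", { type: "error", message: err.message });
  }
}
```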
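Speech audio is now synthesized in the renderer and only persisted in the main process: the preload bridge's `speeches.create` forwards the `{ type, arrayBuffer }` pair over `ipcRenderer.invoke`, and `speeches-handler.ts` writes the buffer under the `speeches/` directory, renames the file to its MD5 hash, and creates the `Speech` row. A sketch of the renderer-side call; the helper, the direct `window.__ENJOY_APP__` access, and the configuration values are illustrative (components normally read `EnjoyApp` from `AppSettingsProviderContext`):

```ts
// Renderer-side sketch: persist a synthesized audio buffer for a message.
// __ENJOY_APP__ is the API exposed via contextBridge in preload.ts.
const EnjoyApp: EnjoyAppType = (window as any).__ENJOY_APP__;

async function saveSpeech(messageId: string, text: string, audio: ArrayBuffer) {
  return EnjoyApp.speeches.create(
    {
      sourceId: messageId,
      sourceType: "Message",
      text,
      configuration: { engine: "openai", model: "tts-1", voice: "alloy" },
    },
    { type: "audio/mp3", arrayBuffer: audio } // ArrayBuffer survives IPC serialization
  );
}
```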
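With the `useConversation` hook, both `ConversationsShortcut` and the conversation page stop going through the old `conversations.ask` IPC round-trip: `chat()` builds the LangChain `ConversationChain` in the renderer, saves the user message plus the assistant replies via `messages.createInBatch`, and resolves with the replies. A usage sketch; the component and its props are hypothetical:

```tsx
import { useConversation } from "@renderer/hooks";

// Hypothetical component: send a prompt to an existing conversation.
export const AskShortcut = (props: {
  conversation: ConversationType;
  prompt: string;
  onReply: (replies: Partial<MessageType>[]) => void;
}) => {
  const { chat } = useConversation();

  const ask = async () => {
    const replies = await chat(
      { content: props.prompt },
      { conversation: props.conversation }
    );
    props.onReply(replies);
  };

  return <button onClick={ask}>Ask</button>;
};
```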
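One detail worth flagging in `fetchChatHistory`: after fetching the newest `historyBufferSize` messages, it re-sorts them with `new Date(...).getUTCMilliseconds()`, which returns only the 0-999 millisecond component of a timestamp rather than its position in time, so the replayed history is not guaranteed to be chronological. Comparing full epoch values keeps the buffer in order; a self-contained sketch:

```ts
// Sketch: oldest-first ordering by full epoch time (getTime()), rather than
// the 0-999 ms component that getUTCMilliseconds() returns.
function sortChronologically(messages: MessageType[]): MessageType[] {
  return [...messages].sort(
    (a, b) =>
      new Date(a.createdAt ?? 0).getTime() - new Date(b.createdAt ?? 0).getTime()
  );
}
```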
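`pickLlm` covers the `enjoyai`, `openai`, `ollama`, and `googleGenerativeAi` engines and implicitly returns `undefined` for anything else, which only surfaces later as a less obvious failure inside `ConversationChain`. A small illustrative guard; the helper name and error message are not part of the patch:

```ts
// Sketch: fail fast when a conversation references an engine the hook does not handle.
function assertSupportedEngine(conversation: ConversationType): void {
  const supported = ["enjoyai", "openai", "ollama", "googleGenerativeAi"];
  if (!supported.includes(conversation.engine)) {
    throw new Error(`Unsupported conversation engine: ${conversation.engine}`);
  }
}
```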
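`AssistantMessageComponent` now synthesizes speech by calling `tts()` from the same hook instead of the `messages.createSpeech` channel: the hook calls the OpenAI `audio.speech.create` endpoint directly (or the EnjoyAI proxy under `${apiUrl}/api/ai`) and stores the result through `speeches.create`, with `tts-1` and `alloy` as the fallback model and voice. A usage sketch; the component is hypothetical and the configuration values are examples:

```tsx
import { useConversation } from "@renderer/hooks";

// Hypothetical playback button: synthesize and persist speech for an assistant message.
export const SpeechButton = (props: { message: MessageType }) => {
  const { tts } = useConversation();

  const createSpeech = () =>
    tts({
      sourceType: "Message",
      sourceId: props.message.id,
      text: props.message.content,
      configuration: { engine: "openai", model: "tts-1", voice: "alloy" },
    });

  return <button onClick={createSpeech}>Speech</button>;
};
```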
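On the type changes: `MessageRoleEnum` lives in a `.d.ts` file, so it is ambient and emits no JavaScript, which is why the renderer code uses it only in casts such as `"user" as MessageRoleEnum` and never as a runtime value like `MessageRoleEnum.USER`. A string-literal union expresses the same set of roles without that pitfall; a sketch, not part of the patch:

```ts
// Sketch: the same role values as a plain union type, usable without casts.
type MessageRole = "system" | "assistant" | "user";

const role: MessageRole = "user";
```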