Fix openai proxy (#244)

* add create messages in batch

* add use conversation

* update conversation shortcut

* add speech handler

* tts in renderer

* fix speech create
This commit is contained in:
an-lee
2024-02-02 00:41:23 +08:00
committed by GitHub
parent 05bfd46a88
commit abde169ead
15 changed files with 358 additions and 27 deletions

View File

@@ -3,5 +3,6 @@ export * from './recordings-handler';
export * from './messages-handler';
export * from './conversations-handler';
export * from './cache-objects-handler';
export * from './speeches-handler';
export * from './transcriptions-handler';
export * from './videos-handler';

View File

@@ -3,6 +3,7 @@ import { Message, Speech, Conversation } from "@main/db/models";
import { FindOptions, WhereOptions, Attributes } from "sequelize";
import log from "electron-log/main";
import { t } from "i18next";
import db from "@main/db";
class MessagesHandler {
private async findAll(
@@ -16,7 +17,7 @@ class MessagesHandler {
model: Speech,
where: { sourceType: "Message" },
required: false,
}
},
],
order: [["createdAt", "DESC"]],
...options,
@@ -79,6 +80,25 @@ class MessagesHandler {
});
}
/**
 * Persists a batch of messages atomically: either every message in
 * `messages` is created or none are. On failure the transaction is
 * rolled back and the error is reported to the renderer through the
 * "on-notification" channel (matching the other handlers' style)
 * instead of being rethrown.
 */
private async createInBatch(event: IpcMainEvent, messages: Message[]) {
  let transaction;
  try {
    transaction = await db.connection.transaction();
    for (const message of messages) {
      await Message.create(message, {
        include: [Conversation],
        transaction,
      });
    }
    await transaction.commit();
  } catch (err) {
    // Roll back so partially-created messages are not left pending;
    // without this the transaction stays open after a failed create.
    if (transaction) {
      await transaction.rollback();
    }
    event.sender.send("on-notification", {
      type: "error",
      message: err.message,
    });
  }
}
private async update(
event: IpcMainEvent,
id: string,
@@ -150,6 +170,7 @@ class MessagesHandler {
ipcMain.handle("messages-find-all", this.findAll);
ipcMain.handle("messages-find-one", this.findOne);
ipcMain.handle("messages-create", this.create);
ipcMain.handle("messages-create-in-batch", this.createInBatch);
ipcMain.handle("messages-update", this.update);
ipcMain.handle("messages-destroy", this.destroy);
ipcMain.handle("messages-create-speech", this.createSpeech);

View File

@@ -0,0 +1,50 @@
import { ipcMain, IpcMainEvent } from "electron";
import { Speech } from "@main/db/models";
import fs from "fs-extra";
import path from "path";
import settings from "@main/settings";
import { hashFile } from "@/utils";
class SpeechesHandler {
  /**
   * Creates a Speech record from a TTS audio blob sent by the renderer.
   *
   * The audio is first written under <userData>/speeches with a
   * timestamp name, hashed, then renamed to "<md5>.<format>" so the
   * stored file is content-addressable. Database errors are reported to
   * the renderer via "on-notification"; resolves with the created
   * speech as JSON, or undefined when the create fails.
   */
  private async create(
    event: IpcMainEvent,
    params: {
      sourceId: string;
      sourceType: string;
      text: string;
      configuration: {
        engine: string;
        model: string;
        voice: string;
      };
    },
    blob: {
      type: string;
      arrayBuffer: ArrayBuffer;
    }
  ) {
    // e.g. "audio/mp3" -> "mp3"; assumes a well-formed MIME type from the
    // renderer — TODO confirm callers always set blob.type.
    const format = blob.type.split("/")[1];
    const filename = `${Date.now()}.${format}`;
    const file = path.join(settings.userDataPath(), "speeches", filename);
    await fs.outputFile(file, Buffer.from(blob.arrayBuffer));

    const md5 = await hashFile(file, { algo: "md5" });
    // Use the promise-based rename: renameSync would block the Electron
    // main-process event loop inside this otherwise async handler.
    await fs.rename(file, path.join(path.dirname(file), `${md5}.${format}`));

    return Speech.create({ ...params, extname: `.${format}`, md5 })
      .then((speech) => {
        return speech.toJSON();
      })
      .catch((err) => {
        event.sender.send("on-notification", {
          type: "error",
          message: err.message,
        });
      });
  }

  register() {
    ipcMain.handle("speeches-create", this.create);
  }
}

export const speechesHandler = new SpeechesHandler();

View File

@@ -19,6 +19,7 @@ import {
conversationsHandler,
messagesHandler,
recordingsHandler,
speechesHandler,
transcriptionsHandler,
videosHandler,
} from "./handlers";
@@ -89,6 +90,7 @@ db.connect = async () => {
recordingsHandler.register();
conversationsHandler.register();
messagesHandler.register();
speechesHandler.register();
transcriptionsHandler.register();
videosHandler.register();
};

View File

@@ -288,7 +288,7 @@ export class Conversation extends Model<Conversation> {
extra: [`--prompt "${prompt}"`],
});
content = transcription
.map((t: TranscriptionSegmentType) => t.text)
.map((t: TranscriptionResultSegmentType) => t.text)
.join(" ")
.trim();

View File

@@ -92,7 +92,7 @@ export class Speech extends Model<Speech> {
@Column(DataType.VIRTUAL)
get voice(): string {
return this.getDataValue("configuration").model;
return this.getDataValue("configuration").voice;
}
@Column(DataType.VIRTUAL)

View File

@@ -306,6 +306,9 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
findOne: (params: object) => {
return ipcRenderer.invoke("messages-find-one", params);
},
createInBatch: (messages: Partial<MessageType>[]) => {
return ipcRenderer.invoke("messages-create-in-batch", messages);
},
destroy: (id: string) => {
return ipcRenderer.invoke("messages-destroy", id);
},
@@ -313,6 +316,23 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
return ipcRenderer.invoke("messages-create-speech", id, configuration);
},
},
speeches: {
create: (
params: {
sourceId: string;
sourceType: string;
text: string;
configuration: {
engine: string;
model: string;
voice: string;
};
},
blob: { type: string; arrayBuffer: ArrayBuffer }
) => {
return ipcRenderer.invoke("speeches-create", params, blob);
},
},
audiowaveform: {
generate: (
file: string,

View File

@@ -4,16 +4,18 @@ import { Button, ScrollArea, toast } from "@renderer/components/ui";
import { LoaderSpin } from "@renderer/components";
import { MessageCircleIcon, LoaderIcon } from "lucide-react";
import { t } from "i18next";
import { useConversation } from "@renderer/hooks";
export const ConversationsShortcut = (props: {
prompt: string;
onReply?: (reply: MessageType[]) => void;
onReply?: (reply: Partial<MessageType>[]) => void;
}) => {
const { EnjoyApp } = useContext(AppSettingsProviderContext);
const { prompt, onReply } = props;
const [conversations, setConversations] = useState<ConversationType[]>([]);
const [loading, setLoading] = useState<boolean>(false);
const [offset, setOffset] = useState<number>(0);
const { chat } = useConversation();
const fetchConversations = () => {
if (offset === -1) return;
@@ -51,10 +53,8 @@ export const ConversationsShortcut = (props: {
const ask = (conversation: ConversationType) => {
setLoading(true);
EnjoyApp.conversations
.ask(conversation.id, {
content: prompt,
})
chat({ content: prompt }, { conversation })
.then((replies) => {
onReply(replies);
})

View File

@@ -31,6 +31,7 @@ import { useCopyToClipboard } from "@uidotdev/usehooks";
import { t } from "i18next";
import { AppSettingsProviderContext } from "@renderer/context";
import Markdown from "react-markdown";
import { useConversation } from "@renderer/hooks";
export const AssistantMessageComponent = (props: {
message: MessageType;
@@ -46,6 +47,7 @@ export const AssistantMessageComponent = (props: {
const [resourcing, setResourcing] = useState<boolean>(false);
const [shadowing, setShadowing] = useState<boolean>(false);
const { EnjoyApp } = useContext(AppSettingsProviderContext);
const { tts } = useConversation();
useEffect(() => {
if (speech) return;
@@ -59,13 +61,12 @@ export const AssistantMessageComponent = (props: {
setSpeeching(true);
EnjoyApp.messages
.createSpeech(message.id, {
engine: configuration?.tts?.engine,
model: configuration?.tts?.model,
voice: configuration?.tts?.voice,
baseUrl: configuration?.tts?.baseUrl,
})
tts({
sourceType: "Message",
sourceId: message.id,
text: message.content,
configuration: configuration.tts,
})
.then((speech) => {
setSpeech(speech);
})

View File

@@ -1,2 +1,3 @@
export * from './useTranscode';
export * from './useAiCommand';
export * from './useConversation';

View File

@@ -0,0 +1,213 @@
import {
AppSettingsProviderContext,
AISettingsProviderContext,
} from "@renderer/context";
import { useContext } from "react";
import { ChatMessageHistory, BufferMemory } from "langchain/memory";
import { ConversationChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { ChatOllama } from "langchain/chat_models/ollama";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
import OpenAI, { type ClientOptions } from "openai";
import { type Generation } from "langchain/dist/schema";
import { v4 } from "uuid";
/**
 * Renderer-side hook providing LLM chat and text-to-speech for a
 * conversation. `chat` runs a langchain ConversationChain with the
 * conversation's buffered history and persists the user message plus
 * all assistant replies in one batch; `tts` synthesizes speech via the
 * OpenAI audio API and stores the result through the speeches IPC.
 */
export const useConversation = () => {
  const { EnjoyApp, user, apiUrl } = useContext(AppSettingsProviderContext);
  const { openai, googleGenerativeAi, currentEngine } = useContext(
    AISettingsProviderContext
  );

  /**
   * Builds a langchain chat model for the conversation's engine using
   * its saved configuration.
   * NOTE(review): falls through and returns undefined for an
   * unrecognized engine — confirm callers only pass supported engines.
   */
  const pickLlm = (conversation: ConversationType) => {
    const {
      baseUrl,
      model,
      temperature,
      maxTokens,
      frequencyPenalty,
      presencePenalty,
      numberOfChoices,
    } = conversation.configuration;

    if (conversation.engine === "enjoyai") {
      // Proxy through the Enjoy API using the user's access token.
      return new ChatOpenAI({
        openAIApiKey: user.accessToken,
        configuration: {
          baseURL: `${apiUrl}/api/ai`,
        },
        modelName: model,
        temperature,
        maxTokens,
        frequencyPenalty,
        presencePenalty,
        n: numberOfChoices,
      });
    } else if (conversation.engine === "openai") {
      return new ChatOpenAI({
        openAIApiKey: openai.key,
        configuration: {
          baseURL: baseUrl,
        },
        modelName: model,
        temperature,
        maxTokens,
        frequencyPenalty,
        presencePenalty,
        n: numberOfChoices,
      });
    } else if (conversation.engine === "ollama") {
      return new ChatOllama({
        baseUrl,
        model,
        temperature,
        frequencyPenalty,
        presencePenalty,
      });
    } else if (conversation.engine === "googleGenerativeAi") {
      return new ChatGoogleGenerativeAI({
        apiKey: googleGenerativeAi.key,
        modelName: model,
        temperature: temperature,
        maxOutputTokens: maxTokens,
      });
    }
  };

  /**
   * Loads up to `historyBufferSize` most recent messages of the
   * conversation and replays them oldest-first into a
   * ChatMessageHistory for use as chain memory.
   */
  const fetchChatHistory = async (conversation: ConversationType) => {
    const chatMessageHistory = new ChatMessageHistory();
    let limit = conversation.configuration.historyBufferSize;
    if (!limit || limit < 0) {
      limit = 0;
    }
    const _messages: MessageType[] = await EnjoyApp.messages.findAll({
      where: { conversationId: conversation.id },
      order: [["createdAt", "DESC"]],
      limit,
    });

    _messages
      .sort(
        (a, b) =>
          // BUG FIX: previously compared getUTCMilliseconds(), which is
          // only the 0-999 ms component of a Date and does not order
          // timestamps. getTime() compares the full epoch time.
          new Date(a.createdAt).getTime() - new Date(b.createdAt).getTime()
      )
      .forEach((message) => {
        if (message.role === "user") {
          chatMessageHistory.addUserMessage(message.content);
        } else if (message.role === "assistant") {
          chatMessageHistory.addAIChatMessage(message.content);
        }
      });

    return chatMessageHistory;
  };

  /**
   * Sends `message.content` to the conversation's LLM with buffered
   * history and the conversation's role definition as system prompt.
   * Persists the user message and all generated replies in one batch,
   * then returns the replies.
   */
  const chat = async (
    message: Partial<MessageType>,
    params: {
      conversation: ConversationType;
    }
  ): Promise<Partial<MessageType>[]> => {
    const { conversation } = params;
    const chatHistory = await fetchChatHistory(conversation);
    const memory = new BufferMemory({
      chatHistory,
      memoryKey: "history",
      returnMessages: true,
    });
    const prompt = ChatPromptTemplate.fromMessages([
      ["system" as MessageRoleEnum, conversation.configuration.roleDefinition],
      new MessagesPlaceholder("history"),
      ["human", "{input}"],
    ]);
    const llm = pickLlm(conversation);
    const chain = new ConversationChain({
      // @ts-ignore
      llm,
      memory,
      prompt,
      verbose: true,
    });
    let response: Generation[] = [];
    // Capture the raw generations so multi-choice (n > 1) responses all
    // become replies, not just the chain's single output string.
    await chain.call({ input: message.content }, [
      {
        handleLLMEnd: async (output) => {
          response = output.generations[0];
        },
      },
    ]);

    const replies = response.map((r) => {
      return {
        id: v4(),
        content: r.text,
        role: "assistant" as MessageRoleEnum,
        conversationId: conversation.id,
      };
    });

    message.role = "user" as MessageRoleEnum;
    message.conversationId = conversation.id;

    await EnjoyApp.messages.createInBatch([message, ...replies]);

    return replies;
  };

  /**
   * Synthesizes speech for `params.text` with the OpenAI audio API
   * (directly, or via the Enjoy proxy when engine is "enjoyai") and
   * persists the audio through the speeches IPC handler.
   */
  const tts = async (params: Partial<SpeechType>) => {
    const { configuration } = params;
    const {
      engine = currentEngine.name,
      model = "tts-1",
      voice = "alloy",
      baseUrl = currentEngine.baseUrl,
    } = configuration || {};

    let client: OpenAI;

    if (engine === "enjoyai") {
      client = new OpenAI({
        apiKey: user.accessToken,
        baseURL: `${apiUrl}/api/ai`,
        dangerouslyAllowBrowser: true,
      });
    } else {
      client = new OpenAI({
        apiKey: openai.key,
        baseURL: baseUrl,
        dangerouslyAllowBrowser: true,
      });
    }

    const file = await client.audio.speech.create({
      input: params.text,
      model,
      voice,
    });
    const buffer = await file.arrayBuffer();

    // NOTE(review): content type is hard-coded to audio/mp3 — confirm the
    // speech endpoint always returns mp3 for the models used here.
    return EnjoyApp.speeches.create(
      {
        text: params.text,
        sourceType: params.sourceType,
        sourceId: params.sourceId,
        configuration: {
          engine,
          model,
          voice,
        },
      },
      {
        type: "audio/mp3",
        arrayBuffer: buffer,
      }
    );
  };

  return {
    chat,
    tts,
  };
};

View File

@@ -19,6 +19,7 @@ import {
import { messagesReducer } from "@renderer/reducers";
import { v4 as uuidv4 } from "uuid";
import autosize from "autosize";
import { useConversation } from "@renderer/hooks";
export default () => {
const { id } = useParams<{ id: string }>();
@@ -35,6 +36,7 @@ export default () => {
const [messages, dispatchMessages] = useReducer(messagesReducer, []);
const [offset, setOffest] = useState(0);
const [loading, setLoading] = useState<boolean>(false);
const { chat } = useConversation();
const inputRef = useRef<HTMLTextAreaElement>(null);
const submitRef = useRef<HTMLButtonElement>(null);
@@ -91,7 +93,7 @@ export default () => {
const message: MessageType = {
id: uuidv4(),
content: text,
role: "user",
role: "user" as MessageRoleEnum,
conversationId: id,
status: "pending",
};
@@ -118,15 +120,8 @@ export default () => {
setSubmitting(false);
}, 1000 * 60 * 5);
EnjoyApp.conversations
.ask(conversation.id, {
messageId: message.id,
content: message.content,
file,
})
.then((reply) => {
if (reply) return;
chat(message, { conversation })
.catch(() => {
message.status = "error";
dispatchMessages({ type: "update", record: message });
})

View File

@@ -183,9 +183,28 @@ type EnjoyAppType = {
messages: {
findAll: (params: object) => Promise<MessageType[]>;
findOne: (params: object) => Promise<MessageType>;
createInBatch: (messages: Partial<MessageType>[]) => Promise<void>;
destroy: (id: string) => Promise<void>;
createSpeech: (id: string, configuration?: any) => Promise<SpeechType>;
};
speeches: {
create: (
params: {
sourceId: string;
sourceType: string;
text: string;
configuration: {
engine: string;
model: string;
voice: string;
};
},
blob: {
type: string;
arrayBuffer: ArrayBuffer;
}
) => Promise<SpeechType>;
};
whisper: {
config: () => Promise<WhisperConfigType>;
check: () => Promise<{ success: boolean; log: string }>;

View File

@@ -1,10 +1,17 @@
enum MessageRoleEnum {
SYSTEM = "system",
ASSISTANT = "assistant",
USER = "user",
}
type MessageType = {
id: string;
role: "system" | "assistant" | "user";
role: MessageRoleEnum;
content: string;
conversationId: string;
conversation?: ConversationType;
createdAt?: string;
createdAt?: Date;
updatedAt?: Date;
status?: "pending" | "success" | "error";
speeches?: Partial<SpeechType>[];
recording?: RecordingType;

View File

@@ -10,6 +10,7 @@ type SpeechType = {
md5: string;
filename: string;
filePath: string;
configuration: {[key: string]: any};
src?: string;
createdAt: Date;
updatedAt: Date;