Fix openai proxy (#244)
* add create messages in batch * add use conversation * update conversation shortcut * add speech handler * tts in renderer * fix speech create
This commit is contained in:
@@ -3,5 +3,6 @@ export * from './recordings-handler';
|
||||
export * from './messages-handler';
|
||||
export * from './conversations-handler';
|
||||
export * from './cache-objects-handler';
|
||||
export * from './speeches-handler';
|
||||
export * from './transcriptions-handler';
|
||||
export * from './videos-handler';
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Message, Speech, Conversation } from "@main/db/models";
|
||||
import { FindOptions, WhereOptions, Attributes } from "sequelize";
|
||||
import log from "electron-log/main";
|
||||
import { t } from "i18next";
|
||||
import db from "@main/db";
|
||||
|
||||
class MessagesHandler {
|
||||
private async findAll(
|
||||
@@ -16,7 +17,7 @@ class MessagesHandler {
|
||||
model: Speech,
|
||||
where: { sourceType: "Message" },
|
||||
required: false,
|
||||
}
|
||||
},
|
||||
],
|
||||
order: [["createdAt", "DESC"]],
|
||||
...options,
|
||||
@@ -79,6 +80,25 @@ class MessagesHandler {
|
||||
});
|
||||
}
|
||||
|
||||
private async createInBatch(event: IpcMainEvent, messages: Message[]) {
|
||||
try {
|
||||
const transaction = await db.connection.transaction();
|
||||
for (const message of messages) {
|
||||
await Message.create(message, {
|
||||
include: [Conversation],
|
||||
transaction,
|
||||
});
|
||||
}
|
||||
|
||||
await transaction.commit();
|
||||
} catch (err) {
|
||||
event.sender.send("on-notification", {
|
||||
type: "error",
|
||||
message: err.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private async update(
|
||||
event: IpcMainEvent,
|
||||
id: string,
|
||||
@@ -150,6 +170,7 @@ class MessagesHandler {
|
||||
ipcMain.handle("messages-find-all", this.findAll);
|
||||
ipcMain.handle("messages-find-one", this.findOne);
|
||||
ipcMain.handle("messages-create", this.create);
|
||||
ipcMain.handle("messages-create-in-batch", this.createInBatch);
|
||||
ipcMain.handle("messages-update", this.update);
|
||||
ipcMain.handle("messages-destroy", this.destroy);
|
||||
ipcMain.handle("messages-create-speech", this.createSpeech);
|
||||
|
||||
50
enjoy/src/main/db/handlers/speeches-handler.ts
Normal file
50
enjoy/src/main/db/handlers/speeches-handler.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { ipcMain, IpcMainEvent } from "electron";
|
||||
import { Speech } from "@main/db/models";
|
||||
import fs from "fs-extra";
|
||||
import path from "path";
|
||||
import settings from "@main/settings";
|
||||
import { hashFile } from "@/utils";
|
||||
|
||||
class SpeechesHandler {
|
||||
private async create(
|
||||
event: IpcMainEvent,
|
||||
params: {
|
||||
sourceId: string;
|
||||
sourceType: string;
|
||||
text: string;
|
||||
configuration: {
|
||||
engine: string;
|
||||
model: string;
|
||||
voice: string;
|
||||
};
|
||||
},
|
||||
blob: {
|
||||
type: string;
|
||||
arrayBuffer: ArrayBuffer;
|
||||
}
|
||||
) {
|
||||
const format = blob.type.split("/")[1];
|
||||
const filename = `${Date.now()}.${format}`;
|
||||
const file = path.join(settings.userDataPath(), "speeches", filename);
|
||||
await fs.outputFile(file, Buffer.from(blob.arrayBuffer));
|
||||
const md5 = await hashFile(file, { algo: "md5" });
|
||||
fs.renameSync(file, path.join(path.dirname(file), `${md5}.${format}`));
|
||||
|
||||
return Speech.create({ ...params, extname: `.${format}`, md5 })
|
||||
.then((speech) => {
|
||||
return speech.toJSON();
|
||||
})
|
||||
.catch((err) => {
|
||||
event.sender.send("on-notification", {
|
||||
type: "error",
|
||||
message: err.message,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
register() {
|
||||
ipcMain.handle("speeches-create", this.create);
|
||||
}
|
||||
}
|
||||
|
||||
export const speechesHandler = new SpeechesHandler();
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
conversationsHandler,
|
||||
messagesHandler,
|
||||
recordingsHandler,
|
||||
speechesHandler,
|
||||
transcriptionsHandler,
|
||||
videosHandler,
|
||||
} from "./handlers";
|
||||
@@ -89,6 +90,7 @@ db.connect = async () => {
|
||||
recordingsHandler.register();
|
||||
conversationsHandler.register();
|
||||
messagesHandler.register();
|
||||
speechesHandler.register();
|
||||
transcriptionsHandler.register();
|
||||
videosHandler.register();
|
||||
};
|
||||
|
||||
@@ -288,7 +288,7 @@ export class Conversation extends Model<Conversation> {
|
||||
extra: [`--prompt "${prompt}"`],
|
||||
});
|
||||
content = transcription
|
||||
.map((t: TranscriptionSegmentType) => t.text)
|
||||
.map((t: TranscriptionResultSegmentType) => t.text)
|
||||
.join(" ")
|
||||
.trim();
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ export class Speech extends Model<Speech> {
|
||||
|
||||
@Column(DataType.VIRTUAL)
|
||||
get voice(): string {
|
||||
return this.getDataValue("configuration").model;
|
||||
return this.getDataValue("configuration").voice;
|
||||
}
|
||||
|
||||
@Column(DataType.VIRTUAL)
|
||||
|
||||
@@ -306,6 +306,9 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
|
||||
findOne: (params: object) => {
|
||||
return ipcRenderer.invoke("messages-find-one", params);
|
||||
},
|
||||
createInBatch: (messages: Partial<MessageType>[]) => {
|
||||
return ipcRenderer.invoke("messages-create-in-batch", messages);
|
||||
},
|
||||
destroy: (id: string) => {
|
||||
return ipcRenderer.invoke("messages-destroy", id);
|
||||
},
|
||||
@@ -313,6 +316,23 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
|
||||
return ipcRenderer.invoke("messages-create-speech", id, configuration);
|
||||
},
|
||||
},
|
||||
speeches: {
|
||||
create: (
|
||||
params: {
|
||||
sourceId: string;
|
||||
sourceType: string;
|
||||
text: string;
|
||||
configuration: {
|
||||
engine: string;
|
||||
model: string;
|
||||
voice: string;
|
||||
};
|
||||
},
|
||||
blob: { type: string; arrayBuffer: ArrayBuffer }
|
||||
) => {
|
||||
return ipcRenderer.invoke("speeches-create", params, blob);
|
||||
},
|
||||
},
|
||||
audiowaveform: {
|
||||
generate: (
|
||||
file: string,
|
||||
|
||||
@@ -4,16 +4,18 @@ import { Button, ScrollArea, toast } from "@renderer/components/ui";
|
||||
import { LoaderSpin } from "@renderer/components";
|
||||
import { MessageCircleIcon, LoaderIcon } from "lucide-react";
|
||||
import { t } from "i18next";
|
||||
import { useConversation } from "@renderer/hooks";
|
||||
|
||||
export const ConversationsShortcut = (props: {
|
||||
prompt: string;
|
||||
onReply?: (reply: MessageType[]) => void;
|
||||
onReply?: (reply: Partial<MessageType>[]) => void;
|
||||
}) => {
|
||||
const { EnjoyApp } = useContext(AppSettingsProviderContext);
|
||||
const { prompt, onReply } = props;
|
||||
const [conversations, setConversations] = useState<ConversationType[]>([]);
|
||||
const [loading, setLoading] = useState<boolean>(false);
|
||||
const [offset, setOffset] = useState<number>(0);
|
||||
const { chat } = useConversation();
|
||||
|
||||
const fetchConversations = () => {
|
||||
if (offset === -1) return;
|
||||
@@ -51,10 +53,8 @@ export const ConversationsShortcut = (props: {
|
||||
|
||||
const ask = (conversation: ConversationType) => {
|
||||
setLoading(true);
|
||||
EnjoyApp.conversations
|
||||
.ask(conversation.id, {
|
||||
content: prompt,
|
||||
})
|
||||
|
||||
chat({ content: prompt }, { conversation })
|
||||
.then((replies) => {
|
||||
onReply(replies);
|
||||
})
|
||||
|
||||
@@ -31,6 +31,7 @@ import { useCopyToClipboard } from "@uidotdev/usehooks";
|
||||
import { t } from "i18next";
|
||||
import { AppSettingsProviderContext } from "@renderer/context";
|
||||
import Markdown from "react-markdown";
|
||||
import { useConversation } from "@renderer/hooks";
|
||||
|
||||
export const AssistantMessageComponent = (props: {
|
||||
message: MessageType;
|
||||
@@ -46,6 +47,7 @@ export const AssistantMessageComponent = (props: {
|
||||
const [resourcing, setResourcing] = useState<boolean>(false);
|
||||
const [shadowing, setShadowing] = useState<boolean>(false);
|
||||
const { EnjoyApp } = useContext(AppSettingsProviderContext);
|
||||
const { tts } = useConversation();
|
||||
|
||||
useEffect(() => {
|
||||
if (speech) return;
|
||||
@@ -59,13 +61,12 @@ export const AssistantMessageComponent = (props: {
|
||||
|
||||
setSpeeching(true);
|
||||
|
||||
EnjoyApp.messages
|
||||
.createSpeech(message.id, {
|
||||
engine: configuration?.tts?.engine,
|
||||
model: configuration?.tts?.model,
|
||||
voice: configuration?.tts?.voice,
|
||||
baseUrl: configuration?.tts?.baseUrl,
|
||||
})
|
||||
tts({
|
||||
sourceType: "Message",
|
||||
sourceId: message.id,
|
||||
text: message.content,
|
||||
configuration: configuration.tts,
|
||||
})
|
||||
.then((speech) => {
|
||||
setSpeech(speech);
|
||||
})
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
export * from './useTranscode';
|
||||
export * from './useAiCommand';
|
||||
export * from './useConversation';
|
||||
|
||||
213
enjoy/src/renderer/hooks/useConversation.tsx
Normal file
213
enjoy/src/renderer/hooks/useConversation.tsx
Normal file
@@ -0,0 +1,213 @@
|
||||
import {
|
||||
AppSettingsProviderContext,
|
||||
AISettingsProviderContext,
|
||||
} from "@renderer/context";
|
||||
import { useContext } from "react";
|
||||
import { ChatMessageHistory, BufferMemory } from "langchain/memory";
|
||||
import { ConversationChain } from "langchain/chains";
|
||||
import { ChatOpenAI } from "langchain/chat_models/openai";
|
||||
import { ChatOllama } from "langchain/chat_models/ollama";
|
||||
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
|
||||
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
|
||||
import OpenAI, { type ClientOptions } from "openai";
|
||||
import { type Generation } from "langchain/dist/schema";
|
||||
import { v4 } from "uuid";
|
||||
|
||||
export const useConversation = () => {
|
||||
const { EnjoyApp, user, apiUrl } = useContext(AppSettingsProviderContext);
|
||||
const { openai, googleGenerativeAi, currentEngine } = useContext(
|
||||
AISettingsProviderContext
|
||||
);
|
||||
|
||||
const pickLlm = (conversation: ConversationType) => {
|
||||
const {
|
||||
baseUrl,
|
||||
model,
|
||||
temperature,
|
||||
maxTokens,
|
||||
frequencyPenalty,
|
||||
presencePenalty,
|
||||
numberOfChoices,
|
||||
} = conversation.configuration;
|
||||
|
||||
if (conversation.engine === "enjoyai") {
|
||||
return new ChatOpenAI({
|
||||
openAIApiKey: user.accessToken,
|
||||
configuration: {
|
||||
baseURL: `${apiUrl}/api/ai`,
|
||||
},
|
||||
modelName: model,
|
||||
temperature,
|
||||
maxTokens,
|
||||
frequencyPenalty,
|
||||
presencePenalty,
|
||||
n: numberOfChoices,
|
||||
});
|
||||
} else if (conversation.engine === "openai") {
|
||||
return new ChatOpenAI({
|
||||
openAIApiKey: openai.key,
|
||||
configuration: {
|
||||
baseURL: baseUrl,
|
||||
},
|
||||
modelName: model,
|
||||
temperature,
|
||||
maxTokens,
|
||||
frequencyPenalty,
|
||||
presencePenalty,
|
||||
n: numberOfChoices,
|
||||
});
|
||||
} else if (conversation.engine === "ollama") {
|
||||
return new ChatOllama({
|
||||
baseUrl,
|
||||
model,
|
||||
temperature,
|
||||
frequencyPenalty,
|
||||
presencePenalty,
|
||||
});
|
||||
} else if (conversation.engine === "googleGenerativeAi") {
|
||||
return new ChatGoogleGenerativeAI({
|
||||
apiKey: googleGenerativeAi.key,
|
||||
modelName: model,
|
||||
temperature: temperature,
|
||||
maxOutputTokens: maxTokens,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const fetchChatHistory = async (conversation: ConversationType) => {
|
||||
const chatMessageHistory = new ChatMessageHistory();
|
||||
let limit = conversation.configuration.historyBufferSize;
|
||||
if (!limit || limit < 0) {
|
||||
limit = 0;
|
||||
}
|
||||
const _messages: MessageType[] = await EnjoyApp.messages.findAll({
|
||||
where: { conversationId: conversation.id },
|
||||
order: [["createdAt", "DESC"]],
|
||||
limit,
|
||||
});
|
||||
|
||||
_messages
|
||||
.sort(
|
||||
(a, b) =>
|
||||
new Date(a.createdAt).getUTCMilliseconds() -
|
||||
new Date(b.createdAt).getUTCMilliseconds()
|
||||
)
|
||||
.forEach((message) => {
|
||||
if (message.role === "user") {
|
||||
chatMessageHistory.addUserMessage(message.content);
|
||||
} else if (message.role === "assistant") {
|
||||
chatMessageHistory.addAIChatMessage(message.content);
|
||||
}
|
||||
});
|
||||
|
||||
return chatMessageHistory;
|
||||
};
|
||||
|
||||
const chat = async (
|
||||
message: Partial<MessageType>,
|
||||
params: {
|
||||
conversation: ConversationType;
|
||||
}
|
||||
): Promise<Partial<MessageType>[]> => {
|
||||
const { conversation } = params;
|
||||
const chatHistory = await fetchChatHistory(conversation);
|
||||
const memory = new BufferMemory({
|
||||
chatHistory,
|
||||
memoryKey: "history",
|
||||
returnMessages: true,
|
||||
});
|
||||
const prompt = ChatPromptTemplate.fromMessages([
|
||||
["system" as MessageRoleEnum, conversation.configuration.roleDefinition],
|
||||
new MessagesPlaceholder("history"),
|
||||
["human", "{input}"],
|
||||
]);
|
||||
|
||||
const llm = pickLlm(conversation);
|
||||
const chain = new ConversationChain({
|
||||
// @ts-ignore
|
||||
llm,
|
||||
memory,
|
||||
prompt,
|
||||
verbose: true,
|
||||
});
|
||||
let response: Generation[] = [];
|
||||
await chain.call({ input: message.content }, [
|
||||
{
|
||||
handleLLMEnd: async (output) => {
|
||||
response = output.generations[0];
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
const replies = response.map((r) => {
|
||||
return {
|
||||
id: v4(),
|
||||
content: r.text,
|
||||
role: "assistant" as MessageRoleEnum,
|
||||
conversationId: conversation.id,
|
||||
};
|
||||
});
|
||||
|
||||
message.role = "user" as MessageRoleEnum;
|
||||
message.conversationId = conversation.id;
|
||||
|
||||
await EnjoyApp.messages.createInBatch([message, ...replies]);
|
||||
|
||||
return replies;
|
||||
};
|
||||
|
||||
const tts = async (params: Partial<SpeechType>) => {
|
||||
const { configuration } = params;
|
||||
const {
|
||||
engine = currentEngine.name,
|
||||
model = "tts-1",
|
||||
voice = "alloy",
|
||||
baseUrl = currentEngine.baseUrl,
|
||||
} = configuration || {};
|
||||
|
||||
let client: OpenAI;
|
||||
|
||||
if (engine === "enjoyai") {
|
||||
client = new OpenAI({
|
||||
apiKey: user.accessToken,
|
||||
baseURL: `${apiUrl}/api/ai`,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
} else {
|
||||
client = new OpenAI({
|
||||
apiKey: openai.key,
|
||||
baseURL: baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
});
|
||||
}
|
||||
|
||||
const file = await client.audio.speech.create({
|
||||
input: params.text,
|
||||
model,
|
||||
voice,
|
||||
});
|
||||
const buffer = await file.arrayBuffer();
|
||||
|
||||
return EnjoyApp.speeches.create(
|
||||
{
|
||||
text: params.text,
|
||||
sourceType: params.sourceType,
|
||||
sourceId: params.sourceId,
|
||||
configuration: {
|
||||
engine,
|
||||
model,
|
||||
voice,
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "audio/mp3",
|
||||
arrayBuffer: buffer,
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
return {
|
||||
chat,
|
||||
tts,
|
||||
};
|
||||
};
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
import { messagesReducer } from "@renderer/reducers";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import autosize from "autosize";
|
||||
import { useConversation } from "@renderer/hooks";
|
||||
|
||||
export default () => {
|
||||
const { id } = useParams<{ id: string }>();
|
||||
@@ -35,6 +36,7 @@ export default () => {
|
||||
const [messages, dispatchMessages] = useReducer(messagesReducer, []);
|
||||
const [offset, setOffest] = useState(0);
|
||||
const [loading, setLoading] = useState<boolean>(false);
|
||||
const { chat } = useConversation();
|
||||
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
const submitRef = useRef<HTMLButtonElement>(null);
|
||||
@@ -91,7 +93,7 @@ export default () => {
|
||||
const message: MessageType = {
|
||||
id: uuidv4(),
|
||||
content: text,
|
||||
role: "user",
|
||||
role: "user" as MessageRoleEnum,
|
||||
conversationId: id,
|
||||
status: "pending",
|
||||
};
|
||||
@@ -118,15 +120,8 @@ export default () => {
|
||||
setSubmitting(false);
|
||||
}, 1000 * 60 * 5);
|
||||
|
||||
EnjoyApp.conversations
|
||||
.ask(conversation.id, {
|
||||
messageId: message.id,
|
||||
content: message.content,
|
||||
file,
|
||||
})
|
||||
.then((reply) => {
|
||||
if (reply) return;
|
||||
|
||||
chat(message, { conversation })
|
||||
.catch(() => {
|
||||
message.status = "error";
|
||||
dispatchMessages({ type: "update", record: message });
|
||||
})
|
||||
|
||||
19
enjoy/src/types/enjoy-app.d.ts
vendored
19
enjoy/src/types/enjoy-app.d.ts
vendored
@@ -183,9 +183,28 @@ type EnjoyAppType = {
|
||||
messages: {
|
||||
findAll: (params: object) => Promise<MessageType[]>;
|
||||
findOne: (params: object) => Promise<MessageType>;
|
||||
createInBatch: (messages: Partial<MessageType>[]) => Promise<void>;
|
||||
destroy: (id: string) => Promise<void>;
|
||||
createSpeech: (id: string, configuration?: any) => Promise<SpeechType>;
|
||||
};
|
||||
speeches: {
|
||||
create: (
|
||||
params: {
|
||||
sourceId: string;
|
||||
sourceType: string;
|
||||
text: string;
|
||||
configuration: {
|
||||
engine: string;
|
||||
model: string;
|
||||
voice: string;
|
||||
};
|
||||
},
|
||||
blob: {
|
||||
type: string;
|
||||
arrayBuffer: ArrayBuffer;
|
||||
}
|
||||
) => Promise<SpeechType>;
|
||||
};
|
||||
whisper: {
|
||||
config: () => Promise<WhisperConfigType>;
|
||||
check: () => Promise<{ success: boolean; log: string }>;
|
||||
|
||||
11
enjoy/src/types/message.d.ts
vendored
11
enjoy/src/types/message.d.ts
vendored
@@ -1,10 +1,17 @@
|
||||
enum MessageRoleEnum {
|
||||
SYSTEM = "system",
|
||||
ASSISTANT = "assistant",
|
||||
USER = "user",
|
||||
}
|
||||
|
||||
type MessageType = {
|
||||
id: string;
|
||||
role: "system" | "assistant" | "user";
|
||||
role: MessageRoleEnum;
|
||||
content: string;
|
||||
conversationId: string;
|
||||
conversation?: ConversationType;
|
||||
createdAt?: string;
|
||||
createdAt?: Date;
|
||||
updatedAt?: Date;
|
||||
status?: "pending" | "success" | "error";
|
||||
speeches?: Partial<SpeechType>[];
|
||||
recording?: RecordingType;
|
||||
|
||||
1
enjoy/src/types/speech.d.ts
vendored
1
enjoy/src/types/speech.d.ts
vendored
@@ -10,6 +10,7 @@ type SpeechType = {
|
||||
md5: string;
|
||||
filename: string;
|
||||
filePath: string;
|
||||
configuration: {[key: string]: any};
|
||||
src?: string;
|
||||
createdAt: Date;
|
||||
updatedAt: Date;
|
||||
|
||||
Reference in New Issue
Block a user