Refactor components code (#538)

* remove deprecated code

* refactor code

* refactor components code

* fix renderer tests
an-lee authored 2024-04-19 10:46:04 +08:00, committed by GitHub
commit e4f5bdcfb9, parent 5f3ee54bb5
30 changed files with 509 additions and 1147 deletions

@@ -113,45 +113,12 @@ class ConversationsHandler {
});
}
private async ask(
event: IpcMainEvent,
id: string,
params: {
messageId: string;
content?: string;
file?: string;
blob?: {
type: string;
arrayBuffer: ArrayBuffer;
};
}
) {
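// Look up the target conversation and notify the renderer if it does not exist.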
const conversation = await Conversation.findOne({
where: { id },
});
if (!conversation) {
event.sender.send("on-notification", {
type: "error",
message: t("models.conversation.notFound"),
});
return;
}
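// Delegate to the model; surface any failure to the renderer as an error notification.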
return conversation.ask(params).catch((err) => {
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
});
}
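// Expose the handlers to the renderer over IPC channels.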
register() {
ipcMain.handle("conversations-find-all", this.findAll);
ipcMain.handle("conversations-find-one", this.findOne);
ipcMain.handle("conversations-create", this.create);
ipcMain.handle("conversations-update", this.update);
ipcMain.handle("conversations-destroy", this.destroy);
ipcMain.handle("conversations-ask", this.ask);
}
}
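
The removed ask handler was reached from the renderer over the channel registered above. A minimal sketch of that calling convention, matching the signature shown in the deleted code (the helper name is hypothetical, not part of this diff):

// Hypothetical renderer-side helper. `ipcMain.handle` receives invoke
// arguments after the event, so invoke("conversations-ask", id, params)
// arrives in the main process as ask(event, id, params).
import { ipcRenderer } from "electron";

function askConversation(
  id: string,
  params: { messageId: string; content?: string; file?: string }
) {
  return ipcRenderer.invoke("conversations-ask", id, params);
}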

@@ -1,4 +1,3 @@
import { app } from "electron";
import {
AfterCreate,
AfterDestroy,
@@ -13,28 +12,8 @@ import {
AllowNull,
} from "sequelize-typescript";
import { Message, Speech } from "@main/db/models";
import { ChatMessageHistory, BufferMemory } from "langchain/memory";
import { ConversationChain } from "langchain/chains";
import { ChatOpenAI } from "@langchain/openai";
import { ChatOllama } from "@langchain/community/chat_models/ollama";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { type Generation } from "langchain/dist/schema";
import settings from "@main/settings";
import db from "@main/db";
import mainWindow from "@main/window";
import { t } from "i18next";
import log from "@main/logger";
import fs from "fs-extra";
import path from "path";
import Ffmpeg from "@main/ffmpeg";
import whisper from "@main/whisper";
import { hashFile } from "@main/utils";
import { WEB_API_URL } from "@/constants";
import proxyAgent from "@main/proxy-agent";
const logger = log.scope("db/models/conversation");
@Table({
@@ -68,7 +47,7 @@ export class Conversation extends Model<Conversation> {
} & { [key: string]: any };
@Column(DataType.VIRTUAL)
- get type(): 'gpt' | 'tts' {
+ get type(): "gpt" | "tts" {
return this.getDataValue("configuration").type || "gpt";
}
@@ -117,263 +96,4 @@ export class Conversation extends Model<Conversation> {
record: conversation.toJSON(),
});
}
// convert messages to chat history
async chatHistory() {
const chatMessageHistory = new ChatMessageHistory();
let limit = this.configuration.historyBufferSize;
if (!limit || limit < 0) {
limit = 0;
}
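// Pull the newest messages first (capped by historyBufferSize), then replay them in chronological order below.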
const _messages = await Message.findAll({
where: { conversationId: this.id },
order: [["createdAt", "DESC"]],
limit,
});
logger.debug(_messages);
_messages
.sort((a, b) => a.createdAt - b.createdAt)
.forEach((message) => {
if (message.role === "user") {
chatMessageHistory.addUserMessage(message.content);
} else if (message.role === "assistant") {
chatMessageHistory.addAIChatMessage(message.content);
}
});
return chatMessageHistory;
}
// choose llm based on engine
llm() {
const { httpAgent, fetch } = proxyAgent();
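// enjoyai routes OpenAI-compatible requests through the Enjoy web API, authenticated with the user's access token.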
if (this.engine === "enjoyai") {
return new ChatOpenAI(
{
openAIApiKey: settings.getSync("user.accessToken") as string,
modelName: this.model,
configuration: {
baseURL: `${process.env.WEB_API_URL || WEB_API_URL}/api/ai`,
},
temperature: this.configuration.temperature,
n: this.configuration.numberOfChoices,
maxTokens: this.configuration.maxTokens,
frequencyPenalty: this.configuration.frequencyPenalty,
presencePenalty: this.configuration.presencePenalty,
},
{
httpAgent,
// @ts-ignore
fetch,
}
);
} else if (this.engine === "openai") {
const key = settings.getSync("openai.key") as string;
if (!key) {
throw new Error(t("openaiKeyRequired"));
}
return new ChatOpenAI(
{
openAIApiKey: key,
modelName: this.model,
configuration: {
baseURL: this.configuration.baseUrl,
},
temperature: this.configuration.temperature,
n: this.configuration.numberOfChoices,
maxTokens: this.configuration.maxTokens,
frequencyPenalty: this.configuration.frequencyPenalty,
presencePenalty: this.configuration.presencePenalty,
},
{
httpAgent,
// @ts-ignore
fetch,
}
);
} else if (this.engine === "googleGenerativeAi") {
const key = settings.getSync("googleGenerativeAi.key") as string;
if (!key) {
throw new Error(t("googleGenerativeAiKeyRequired"));
}
return new ChatGoogleGenerativeAI({
apiKey: key,
modelName: this.model,
temperature: this.configuration.temperature,
maxOutputTokens: this.configuration.maxTokens,
});
} else if (this.engine === "ollama") {
return new ChatOllama({
baseUrl: this.configuration.baseUrl,
model: this.model,
temperature: this.configuration.temperature,
frequencyPenalty: this.configuration.frequencyPenalty,
presencePenalty: this.configuration.presencePenalty,
});
}
}
// choose memory based on conversation scenario
async memory() {
const chatHistory = await this.chatHistory();
return new BufferMemory({
chatHistory,
memoryKey: "history",
returnMessages: true,
});
}
chatPrompt() {
return ChatPromptTemplate.fromMessages([
["system", this.roleDefinition],
new MessagesPlaceholder("history"),
["human", "{input}"],
]);
}
async chain() {
return new ConversationChain({
llm: this.llm(),
memory: await this.memory(),
prompt: this.chatPrompt(),
verbose: !app.isPackaged,
});
}
async ask(params: {
messageId?: string;
content?: string;
file?: string;
blob?: {
type: string;
arrayBuffer: ArrayBuffer;
};
}) {
let { content, file, blob, messageId } = params;
if (!content && !blob) {
throw new Error(t("models.conversation.contentRequired"));
}
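// The question audio may arrive as a file on disk or as a recorded blob; either way it ends up as a hash-named file under the speeches directory.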
let md5 = "";
let extname = ".wav";
if (file) {
extname = path.extname(file);
md5 = await hashFile(file, { algo: "md5" });
fs.copySync(
file,
path.join(settings.userDataPath(), "speeches", `${md5}${extname}`)
);
} else if (blob) {
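// Recorded audio arrives as a blob: write it to a temp file, then convert it to wav.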
const filename = `${Date.now()}${extname}`;
const format = blob.type.split("/")[1];
const tempfile = path.join(
settings.cachePath(),
`${Date.now()}.${format}`
);
await fs.outputFile(tempfile, Buffer.from(blob.arrayBuffer));
const wavFile = path.join(settings.userDataPath(), "speeches", filename);
const ffmpeg = new Ffmpeg();
await ffmpeg.convertToWav(tempfile, wavFile);
md5 = await hashFile(wavFile, { algo: "md5" });
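// Rename the wav to its md5 hash so the stored filename is content-addressed.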
fs.renameSync(
wavFile,
path.join(path.dirname(wavFile), `${md5}${extname}`)
);
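// Seed whisper with the previous message as its prompt to bias decoding toward the conversation context.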
const previousMessage = await Message.findOne({
where: { conversationId: this.id },
order: [["createdAt", "DESC"]],
});
let prompt = "";
if (previousMessage?.content) {
prompt = previousMessage.content.replace(/"/g, '\\"');
}
const { transcription } = await whisper.transcribe(wavFile, {
force: true,
extra: [`--prompt "${prompt}"`],
});
content = transcription
.map((t: TranscriptionResultSegmentType) => t.text)
.join(" ")
.trim();
logger.debug("transcription", transcription);
}
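// Run the chain and capture the raw generations via the handleLLMEnd callback.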
const chain = await this.chain();
let response: Generation[] = [];
const result = await chain.call({ input: content }, [
{
handleLLMEnd: async (output) => {
response = output.generations[0];
},
},
]);
logger.debug("LLM result:", result);
if (response.length === 0) {
throw new Error(t("models.conversation.failedToGenerateResponse"));
}
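// Persist the user message, the assistant replies, and (for audio input) a human Speech record atomically.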
const transaction = await db.connection.transaction();
await Message.create(
{
id: messageId,
conversationId: this.id,
role: "user",
content,
},
{
include: [Conversation],
transaction,
}
);
const replies = await Promise.all(
response.map(async (generation) => {
if (!generation?.text) {
throw new Error(t("models.conversation.failedToGenerateResponse"));
}
return await Message.create(
{
conversationId: this.id,
role: "assistant",
content: generation.text,
},
{
include: [Conversation],
transaction,
}
);
})
);
if (md5) {
await Speech.create(
{
sourceId: messageId,
sourceType: "message",
text: content,
md5,
extname,
configuration: {
engine: "Human",
},
},
{
include: [Message],
transaction,
}
);
}
await transaction.commit();
return replies.map((reply) => reply.toJSON());
}
}
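
For context, the deleted ask flow wired up the standard LangChain conversation pattern: a prompt with a history placeholder, buffer memory rebuilt from stored messages, and an engine-specific chat model. A minimal self-contained sketch of that wiring, using the same imports the removed code relied on (the key, model name, and system prompt below are placeholders, not values from this codebase):

// Minimal sketch of the chain assembled by the removed chain()/ask() methods.
import { ChatOpenAI } from "@langchain/openai";
import { ConversationChain } from "langchain/chains";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import {
  ChatPromptTemplate,
  MessagesPlaceholder,
} from "@langchain/core/prompts";

async function main() {
  const prompt = ChatPromptTemplate.fromMessages([
    ["system", "You are a helpful language tutor."], // stand-in for roleDefinition
    new MessagesPlaceholder("history"),
    ["human", "{input}"],
  ]);

  const chain = new ConversationChain({
    llm: new ChatOpenAI({
      openAIApiKey: "sk-placeholder",
      modelName: "gpt-3.5-turbo",
    }),
    memory: new BufferMemory({
      chatHistory: new ChatMessageHistory(), // the removed code replayed Message rows here
      memoryKey: "history",
      returnMessages: true,
    }),
    prompt,
  });

  // ConversationChain's default output key is "response".
  const result = await chain.call({ input: "Hello" });
  console.log(result.response);
}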