Feat: add Enjoy AI as an option (#206)

* add enjoyAI as an option

* use enjoyai config

* may call enjoyai

* may set default ai engine

* refactor setting context

* refactor preferences

* add warning when OpenAI key is not provided

* tweak locale

* update duration for audio/video

* add balance settings

* may select AI role when creating a conversation

* may forward messages from a conversation

* tweak ui

* refactor transcribe method

* refactor ai commands to hooks

* fix webapi

* tweak playback rate options

* add playMode, next & prev, ref: #124

* upgrade deps

* may skip whisper model download

* audios/videos default order by updatedAt
an-lee authored on 2024-01-31 00:04:59 +08:00, committed by GitHub
parent 58dcd1523e
commit 00cbc8403b
56 changed files with 1590 additions and 858 deletions

View File

@@ -14,7 +14,7 @@ class AudiosHandler {
options: FindOptions<Attributes<Audio>>
) {
return Audio.findAll({
order: [["createdAt", "DESC"]],
order: [["updatedAt", "DESC"]],
include: [
{
association: "transcription",
@@ -66,39 +66,6 @@ class AudiosHandler {
});
}
private async transcribe(event: IpcMainEvent, id: string) {
const audio = await Audio.findOne({
where: {
id,
},
});
if (!audio) {
event.sender.send("on-notification", {
type: "error",
message: t("models.audio.notFound"),
});
}
const timeout = setTimeout(() => {
event.sender.send("on-notification", {
type: "warning",
message: t("stillTranscribing"),
});
}, 1000 * 10);
audio
.transcribe()
.catch((err) => {
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
})
.finally(() => {
clearTimeout(timeout);
});
}
private async create(
event: IpcMainEvent,
uri: string,
@@ -148,7 +115,7 @@ class AudiosHandler {
id: string,
params: Attributes<Audio>
) {
const { name, description, transcription } = params;
const { name, description, metadata } = params;
return Audio.findOne({
where: { id },
@@ -157,7 +124,7 @@ class AudiosHandler {
if (!audio) {
throw new Error(t("models.audio.notFound"));
}
audio.update({ name, description, transcription });
audio.update({ name, description, metadata });
})
.catch((err) => {
event.sender.send("on-notification", {
@@ -208,7 +175,6 @@ class AudiosHandler {
register() {
ipcMain.handle("audios-find-all", this.findAll);
ipcMain.handle("audios-find-one", this.findOne);
ipcMain.handle("audios-transcribe", this.transcribe);
ipcMain.handle("audios-create", this.create);
ipcMain.handle("audios-update", this.update);
ipcMain.handle("audios-destroy", this.destroy);

View File

@@ -86,7 +86,7 @@ class TranscriptionsHandler {
throw new Error("models.transcription.notFound");
}
const timeout = setTimeout(() => {
const interval = setInterval(() => {
event.sender.send("on-notification", {
type: "warning",
message: t("stillTranscribing"),
@@ -102,7 +102,7 @@ class TranscriptionsHandler {
});
})
.finally(() => {
clearTimeout(timeout);
clearInterval(interval);
});
})
.catch((err) => {
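
The "still transcribing" notice moves from a one-shot setTimeout to a repeating setInterval, so the warning keeps reminding the user for as long as the transcription job runs, and the interval is cleared once the underlying promise settles. The same pattern in isolation (a sketch, not code from this diff):

// Repeat a progress notice every 10 seconds until a long-running task settles.
function withProgressNotice<T>(task: Promise<T>, notify: () => void): Promise<T> {
  const interval = setInterval(notify, 1000 * 10);
  return task.finally(() => clearInterval(interval));
}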

View File

@@ -14,7 +14,7 @@ class VideosHandler {
options: FindOptions<Attributes<Video>>
) {
return Video.findAll({
order: [["createdAt", "DESC"]],
order: [["updatedAt", "DESC"]],
include: [
{
association: "transcription",
@@ -66,39 +66,6 @@ class VideosHandler {
});
}
private async transcribe(event: IpcMainEvent, id: string) {
const video = await Video.findOne({
where: {
id,
},
});
if (!video) {
event.sender.send("on-notification", {
type: "error",
message: t("models.video.notFound"),
});
}
const timeout = setTimeout(() => {
event.sender.send("on-notification", {
type: "warning",
message: t("stillTranscribing"),
});
}, 1000 * 10);
video
.transcribe()
.catch((err) => {
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
})
.finally(() => {
clearTimeout(timeout);
});
}
private async create(
event: IpcMainEvent,
uri: string,
@@ -149,7 +116,7 @@ class VideosHandler {
id: string,
params: Attributes<Video>
) {
const { name, description, transcription } = params;
const { name, description, metadata } = params;
return Video.findOne({
where: { id },
@@ -158,7 +125,7 @@ class VideosHandler {
if (!video) {
throw new Error(t("models.video.notFound"));
}
video.update({ name, description, transcription });
video.update({ name, description, metadata });
})
.catch((err) => {
event.sender.send("on-notification", {
@@ -209,7 +176,6 @@ class VideosHandler {
register() {
ipcMain.handle("videos-find-all", this.findAll);
ipcMain.handle("videos-find-one", this.findOne);
ipcMain.handle("videos-transcribe", this.transcribe);
ipcMain.handle("videos-create", this.create);
ipcMain.handle("videos-update", this.update);
ipcMain.handle("videos-destroy", this.destroy);

View File

@@ -34,12 +34,6 @@ const SIZE_LIMIT = 1024 * 1024 * 50; // 50MB
const logger = log.scope("db/models/audio");
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
@Table({
modelName: "Audio",
tableName: "audios",
@@ -119,7 +113,7 @@ export class Audio extends Model<Audio> {
@Column(DataType.VIRTUAL)
get transcribed(): boolean {
return this.transcription?.state === "finished";
return Boolean(this.transcription?.result);
}
@Column(DataType.VIRTUAL)
@@ -131,6 +125,11 @@ export class Audio extends Model<Audio> {
)}`;
}
@Column(DataType.VIRTUAL)
get duration(): number {
return this.getDataValue("metadata").duration;
}
get extname(): string {
return (
this.getDataValue("metadata").extname ||
@@ -167,9 +166,13 @@ export class Audio extends Model<Audio> {
}
async sync() {
if (!this.isUploaded) {
this.upload();
}
if (this.isSynced) return;
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
return webApi.syncAudio(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
@@ -212,6 +215,7 @@ export class Audio extends Model<Audio> {
@AfterUpdate
static notifyForUpdate(audio: Audio) {
this.notify(audio, "update");
audio.sync().catch(() => {});
}
@AfterDestroy
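
Beyond the transcribed getter (now derived from the presence of a transcription result) and the new duration virtual column, the notable change is that the module-level webApi Client is replaced by one built inside sync(). A module-level instance captures settings.getSync("user.accessToken") once, at import time, so a token stored after startup (e.g. after signing in) would never be picked up; constructing the client per call always reads the current value. The same refactor is repeated in the other models and in the whisper module below. The contrast in isolation, reusing the Client, settings, and log imports already in this file:

// Before: token captured once, when the module is first loaded.
const staleApi = new Client({
  baseUrl: process.env.WEB_API_URL || WEB_API_URL,
  accessToken: settings.getSync("user.accessToken") as string, // evaluated at import time
  logger: log.scope("api/client"),
});

// After: token read at call time, inside the method that needs it.
function buildWebApi(): Client {
  return new Client({
    baseUrl: process.env.WEB_API_URL || WEB_API_URL,
    accessToken: settings.getSync("user.accessToken") as string, // evaluated per call
    logger: log.scope("api/client"),
  });
}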

View File

@@ -30,6 +30,7 @@ import path from "path";
import Ffmpeg from "@main/ffmpeg";
import whisper from "@main/whisper";
import { hashFile } from "@/utils";
import { WEB_API_URL } from "@/constants";
const logger = log.scope("db/models/conversation");
@Table({
@@ -136,7 +137,22 @@ export class Conversation extends Model<Conversation> {
// choose llm based on engine
llm() {
if (this.engine == "openai") {
if (this.engine === "enjoyai") {
return new ChatOpenAI({
modelName: this.model,
configuration: {
baseURL: `${process.env.WEB_API_URL || WEB_API_URL}/api/ai`,
defaultHeaders: {
Authorization: `Bearer ${settings.getSync("user.accessToken")}`,
},
},
temperature: this.configuration.temperature,
n: this.configuration.numberOfChoices,
maxTokens: this.configuration.maxTokens,
frequencyPenalty: this.configuration.frequencyPenalty,
presencePenalty: this.configuration.presencePenalty,
});
} else if (this.engine === "openai") {
const key = settings.getSync("openai.key") as string;
if (!key) {
throw new Error(t("openaiKeyRequired"));

View File

@@ -19,12 +19,6 @@ import { WEB_API_URL } from "@/constants";
import settings from "@main/settings";
import log from "electron-log/main";
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
@Table({
modelName: "PronunciationAssessment",
tableName: "pronunciation_assessments",
@@ -40,7 +34,7 @@ const webApi = new Client({
},
}))
export class PronunciationAssessment extends Model<PronunciationAssessment> {
@IsUUID('all')
@IsUUID("all")
@Default(DataType.UUIDV4)
@Column({ primaryKey: true, type: DataType.UUID })
id: string;
@@ -100,6 +94,12 @@ export class PronunciationAssessment extends Model<PronunciationAssessment> {
}
async sync() {
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
return webApi.syncPronunciationAssessment(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
});

View File

@@ -29,12 +29,6 @@ import camelcaseKeys from "camelcase-keys";
const logger = log.scope("db/models/recording");
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
@Table({
modelName: "Recording",
tableName: "recordings",
@@ -144,6 +138,12 @@ export class Recording extends Model<Recording> {
await this.upload();
}
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
return webApi.syncRecording(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
});
@@ -158,6 +158,12 @@ export class Recording extends Model<Recording> {
return assessment;
}
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
const { token, region } = await webApi.generateSpeechToken();
const sdk = new AzureSpeechSdk(token, region);

View File

@@ -23,6 +23,7 @@ import { t } from "i18next";
import { hashFile } from "@/utils";
import { Audio, Message } from "@main/db/models";
import log from "electron-log/main";
import { WEB_API_URL } from "@/constants";
const logger = log.scope("db/models/speech");
@Table({
@@ -170,26 +171,34 @@ export class Speech extends Model<Speech> {
const filename = `${Date.now()}${extname}`;
const filePath = path.join(settings.userDataPath(), "speeches", filename);
if (engine === "openai") {
const key = settings.getSync("openai.key") as string;
if (!key) {
let openaiConfig = {};
if (engine === "enjoyai") {
openaiConfig = {
baseURL: `${process.env.WEB_API_URL || WEB_API_URL}/api/ai`,
defaultHeaders: {
Authorization: `Bearer ${settings.getSync("user.accessToken")}`,
},
};
} else if (engine === "openai") {
const defaultConfig = settings.getSync("openai") as LlmProviderType;
if (!defaultConfig.key) {
throw new Error(t("openaiKeyRequired"));
}
const openai = new OpenAI({
apiKey: key,
baseURL: baseUrl,
});
logger.debug("baseURL", openai.baseURL);
const file = await openai.audio.speech.create({
input: text,
model,
voice,
});
const buffer = Buffer.from(await file.arrayBuffer());
await fs.outputFile(filePath, buffer);
openaiConfig = {
apiKey: defaultConfig.key,
baseURL: baseUrl || defaultConfig.baseUrl,
};
}
const openai = new OpenAI(openaiConfig);
const file = await openai.audio.speech.create({
input: text,
model,
voice,
});
const buffer = Buffer.from(await file.arrayBuffer());
await fs.outputFile(filePath, buffer);
const md5 = await hashFile(filePath, { algo: "md5" });
fs.renameSync(
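
Speech synthesis now builds its OpenAI client configuration per engine: "enjoyai" points the official openai SDK at the Enjoy relay and authenticates with the user's accessToken via a default header, while "openai" keeps using the key and base URL from local settings. A standalone sketch of the relay variant, reusing the fs, settings, and constants already imported in this model; note the openai SDK expects some apiKey value (or an OPENAI_API_KEY environment variable), so a placeholder is passed here, which is not something shown in this hunk:

import OpenAI from "openai";

// Text-to-speech through the Enjoy relay instead of api.openai.com.
const openai = new OpenAI({
  apiKey: "enjoyai", // placeholder; the real credential is the Authorization header below
  baseURL: `${process.env.WEB_API_URL || WEB_API_URL}/api/ai`,
  defaultHeaders: {
    Authorization: `Bearer ${settings.getSync("user.accessToken")}`,
  },
});

const file = await openai.audio.speech.create({
  input: "Nice to meet you.",
  model: "tts-1",
  voice: "alloy",
});
await fs.outputFile(filePath, Buffer.from(await file.arrayBuffer()));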

View File

@@ -24,11 +24,6 @@ import path from "path";
import fs from "fs-extra";
const logger = log.scope("db/models/transcription");
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
@Table({
modelName: "Transcription",
@@ -82,6 +77,11 @@ export class Transcription extends Model<Transcription> {
async sync() {
if (this.getDataValue("state") !== "finished") return;
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
return webApi.syncTranscription(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
});

View File

@@ -34,12 +34,6 @@ const SIZE_LIMIT = 1024 * 1024 * 100; // 100MB
const logger = log.scope("db/models/video");
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
@Table({
modelName: "Video",
tableName: "videos",
@@ -131,6 +125,11 @@ export class Video extends Model<Video> {
)}`;
}
@Column(DataType.VIRTUAL)
get duration(): number {
return this.getDataValue("metadata").duration;
}
get extname(): string {
return (
this.getDataValue("metadata").extname ||
@@ -189,9 +188,13 @@ export class Video extends Model<Video> {
}
async sync() {
if (!this.isUploaded) {
this.upload();
}
if (this.isSynced) return;
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
return webApi.syncVideo(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
@@ -235,6 +238,7 @@ export class Video extends Model<Video> {
@AfterUpdate
static notifyForUpdate(video: Video) {
this.notify(video, "update");
video.sync().catch(() => {});
}
@AfterDestroy

View File

@@ -167,6 +167,14 @@ export default {
ipcMain.handle("settings-switch-language", (_event, language) => {
switchLanguage(language);
});
ipcMain.handle("settings-get-default-engine", (_event) => {
return settings.getSync("defaultEngine");
});
ipcMain.handle("settings-set-default-engine", (_event, engine) => {
return settings.setSync("defaultEngine", engine);
});
},
cachePath,
libraryPath,
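
Two new IPC channels expose the defaultEngine preference behind the "may set default ai engine" bullet. A hypothetical renderer-facing exposure through the preload script; only the channel names come from this diff, the __ENJOY_APP__ name and nesting are assumed:

import { contextBridge, ipcRenderer } from "electron";

// Hypothetical preload bridge; the real exposure may be named and shaped differently.
contextBridge.exposeInMainWorld("__ENJOY_APP__", {
  settings: {
    getDefaultEngine: (): Promise<"openai" | "enjoyai"> =>
      ipcRenderer.invoke("settings-get-default-engine"),
    setDefaultEngine: (engine: "openai" | "enjoyai") =>
      ipcRenderer.invoke("settings-set-default-engine", engine),
  },
});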

View File

@@ -19,12 +19,6 @@ import { sortedUniqBy, take } from "lodash";
const logger = log.scope("whisper");
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
const MAGIC_TOKENS = ["Mrs.", "Ms.", "Mr.", "Dr.", "Prof.", "St."];
const END_OF_WORD_REGEX = /[^\.!,\?][\.!\?]/g;
class Whipser {
@@ -200,6 +194,11 @@ class Whipser {
}
async transcribeFromAzure(file: string): Promise<Partial<WhisperOutputType>> {
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
});
const { token, region } = await webApi.generateSpeechToken();
const sdk = new AzureSpeechSdk(token, region);
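
As in the models above, the whisper module now builds its web API client inside transcribeFromAzure rather than at module load. This path transcribes through Azure's speech service with a short-lived token issued by the web API, which is presumably what the "may skip whisper model download" bullet enables. A condensed sketch of the handshake; the sdk.transcribe call is an assumed method name on the project's own AzureSpeechSdk wrapper, and the actual recognition code sits outside this hunk:

// Fetch a short-lived Azure token from the web API, then run recognition with it.
const webApi = new Client({
  baseUrl: process.env.WEB_API_URL || WEB_API_URL,
  accessToken: settings.getSync("user.accessToken") as string,
  logger: log.scope("api/client"),
});
const { token, region } = await webApi.generateSpeechToken();
const sdk = new AzureSpeechSdk(token, region);
const results = await sdk.transcribe(file); // assumed API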