From 93d9e190e240cf27b08f6760cd03617e01978836 Mon Sep 17 00:00:00 2001 From: an-lee Date: Wed, 17 Jul 2024 17:25:42 +0800 Subject: [PATCH] Refactor pronunciation assess (#820) * refactor pronunciation assessment request * remove unused code --- .../pronunciation-assessments-handler.ts | 29 ++- enjoy/src/main/db/models/recording.ts | 65 ------ enjoy/src/preload.ts | 3 - .../pronunciation-assessment-form.tsx | 73 ++++--- .../recordings/recording-detail.tsx | 11 +- enjoy/src/renderer/context/index.ts | 1 - .../renderer/context/wavesurfer-provider.tsx | 185 ------------------ .../hooks/use-pronunciation-assessments.tsx | 77 +++++++- enjoy/src/types/enjoy-app.d.ts | 1 - enjoy/src/types/recording.d.ts | 1 + 10 files changed, 141 insertions(+), 305 deletions(-) delete mode 100644 enjoy/src/renderer/context/wavesurfer-provider.tsx diff --git a/enjoy/src/main/db/handlers/pronunciation-assessments-handler.ts b/enjoy/src/main/db/handlers/pronunciation-assessments-handler.ts index 00c3ed2e..0e426fa9 100644 --- a/enjoy/src/main/db/handlers/pronunciation-assessments-handler.ts +++ b/enjoy/src/main/db/handlers/pronunciation-assessments-handler.ts @@ -47,27 +47,22 @@ class PronunciationAssessmentsHandler { private async create( _event: IpcMainEvent, - data: Partial> & { - blob: { - type: string; - arrayBuffer: ArrayBuffer; - }; - } + data: Partial> ) { - const recording = await Recording.createFromBlob(data.blob, { - targetId: "00000000-0000-0000-0000-000000000000", - targetType: "None", - referenceText: data.referenceText, - language: data.language, + const { targetId, targetType } = data; + const existed = await PronunciationAssessment.findOne({ + where: { + targetId, + targetType, + }, }); - try { - const assessment = await recording.assess(data.language); - return assessment.toJSON(); - } catch (error) { - await recording.destroy(); - throw error; + if (existed) { + return existed.toJSON(); } + + const assessment = await PronunciationAssessment.create(data); + return assessment.toJSON(); } private async update( diff --git a/enjoy/src/main/db/models/recording.ts b/enjoy/src/main/db/models/recording.ts index d55161ea..3ddbc610 100644 --- a/enjoy/src/main/db/models/recording.ts +++ b/enjoy/src/main/db/models/recording.ts @@ -24,9 +24,7 @@ import log from "@main/logger"; import storage from "@main/storage"; import { Client } from "@/api"; import { WEB_API_URL } from "@/constants"; -import { AzureSpeechSdk } from "@main/azure-speech-sdk"; import echogarden from "@main/echogarden"; -import camelcaseKeys from "camelcase-keys"; import { t } from "i18next"; import { Attributes } from "sequelize"; import { v5 as uuidv5 } from "uuid"; @@ -156,69 +154,6 @@ export class Recording extends Model { }); } - async assess(language?: string) { - const assessment = await PronunciationAssessment.findOne({ - where: { targetId: this.id, targetType: "Recording" }, - }); - - if (assessment) { - return assessment; - } - - await this.sync(); - const webApi = new Client({ - baseUrl: process.env.WEB_API_URL || WEB_API_URL, - accessToken: settings.getSync("user.accessToken") as string, - logger, - }); - - const { - id: tokenId, - token, - region, - } = await webApi.generateSpeechToken({ - targetId: this.id, - targetType: "Recording", - }); - const sdk = new AzureSpeechSdk(token, region); - - const result = await sdk.pronunciationAssessment({ - filePath: this.filePath, - reference: this.referenceText, - language: language || this.language, - }); - - const resultJson = camelcaseKeys( - JSON.parse(JSON.stringify(result.detailResult)), - { - deep: true, - } - ); - resultJson.duration = this.duration; - resultJson.tokenId = tokenId; - - const _pronunciationAssessment = await PronunciationAssessment.create( - { - targetId: this.id, - targetType: "Recording", - pronunciationScore: result.pronunciationScore, - accuracyScore: result.accuracyScore, - completenessScore: result.completenessScore, - fluencyScore: result.fluencyScore, - prosodyScore: result.prosodyScore, - grammarScore: result.contentAssessmentResult?.grammarScore, - vocabularyScore: result.contentAssessmentResult?.vocabularyScore, - topicScore: result.contentAssessmentResult?.topicScore, - result: resultJson, - language: language || this.language, - }, - { - include: Recording, - } - ); - return _pronunciationAssessment; - } - @AfterFind static async findTarget(findResult: Recording | Recording[]) { if (!findResult) return; diff --git a/enjoy/src/preload.ts b/enjoy/src/preload.ts index f616440a..0cb10864 100644 --- a/enjoy/src/preload.ts +++ b/enjoy/src/preload.ts @@ -327,9 +327,6 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", { upload: (id: string) => { return ipcRenderer.invoke("recordings-upload", id); }, - assess: (id: string, language?: string) => { - return ipcRenderer.invoke("recordings-assess", id, language); - }, stats: (params: { from: string; to: string }) => { return ipcRenderer.invoke("recordings-stats", params); }, diff --git a/enjoy/src/renderer/components/pronunciation-assessments/pronunciation-assessment-form.tsx b/enjoy/src/renderer/components/pronunciation-assessments/pronunciation-assessment-form.tsx index 8f5d0029..37d8f557 100644 --- a/enjoy/src/renderer/components/pronunciation-assessments/pronunciation-assessment-form.tsx +++ b/enjoy/src/renderer/components/pronunciation-assessments/pronunciation-assessment-form.tsx @@ -29,10 +29,11 @@ import { zodResolver } from "@hookform/resolvers/zod"; import { LoaderIcon, MicIcon, SquareIcon } from "lucide-react"; import WaveSurfer from "wavesurfer.js"; import RecordPlugin from "wavesurfer.js/dist/plugins/record"; +import { usePronunciationAssessments } from "@/renderer/hooks"; const pronunciationAssessmentSchema = z.object({ file: z.instanceof(FileList).optional(), - recording: z.instanceof(Blob).optional(), + recordingFile: z.instanceof(Blob).optional(), language: z.string().min(2), referenceText: z.string().optional(), }); @@ -41,6 +42,7 @@ export const PronunciationAssessmentForm = () => { const navigate = useNavigate(); const { EnjoyApp, learningLanguage } = useContext(AppSettingsProviderContext); const [submitting, setSubmitting] = useState(false); + const { createAssessment } = usePronunciationAssessments(); const form = useForm>({ resolver: zodResolver(pronunciationAssessmentSchema), @@ -55,34 +57,28 @@ export const PronunciationAssessmentForm = () => { const onSubmit = async ( data: z.infer ) => { - console.log(data); - if ((!data.file || data.file.length === 0) && !data.recording) { + if ((!data.file || data.file.length === 0) && !data.recordingFile) { toast.error(t("noFileOrRecording")); - form.setError("recording", { message: t("noFileOrRecording") }); + form.setError("recordingFile", { message: t("noFileOrRecording") }); return; } - const { language, referenceText, file, recording } = data; - let arrayBuffer: ArrayBuffer; - if (recording) { - arrayBuffer = await recording.arrayBuffer(); - } else { - arrayBuffer = await new Blob([file[0]]).arrayBuffer(); - } + const { language, referenceText } = data; + + const recording = await createRecording(data); setSubmitting(true); toast.promise( - EnjoyApp.pronunciationAssessments - .create({ - language, - referenceText, - blob: { - type: recording?.type || file[0].type, - arrayBuffer, - }, - }) + createAssessment({ + language, + reference: referenceText, + recording, + }) .then(() => { navigate("/pronunciation_assessments"); }) + .catch(() => { + EnjoyApp.recordings.destroy(recording.id); + }) .finally(() => setSubmitting(false)), { loading: t("assessing"), @@ -92,6 +88,35 @@ export const PronunciationAssessmentForm = () => { ); }; + const createRecording = async ( + data: z.infer + ): Promise => { + const { language, referenceText, file, recordingFile } = data; + let arrayBuffer: ArrayBuffer; + if (recordingFile) { + arrayBuffer = await recordingFile.arrayBuffer(); + } else { + arrayBuffer = await new Blob([file[0]]).arrayBuffer(); + } + const recording = await EnjoyApp.recordings.create({ + language, + referenceText, + blob: { + type: recordingFile?.type || file[0].type, + arrayBuffer, + }, + }); + + try { + await EnjoyApp.recordings.sync(recording.id); + return recording; + } catch (err) { + toast.error(err.message); + EnjoyApp.recordings.destroy(recording.id); + return; + } + }; + return (
@@ -128,7 +153,7 @@ export const PronunciationAssessmentForm = () => {
( { /> { - form.resetField("recording"); + form.resetField("recordingFile"); }} onFinish={(blob) => { field.onChange(blob); @@ -149,11 +174,11 @@ export const PronunciationAssessmentForm = () => { )} /> - {form.watch("recording") && ( + {form.watch("recordingFile") && (
diff --git a/enjoy/src/renderer/components/recordings/recording-detail.tsx b/enjoy/src/renderer/components/recordings/recording-detail.tsx index a54cc3c3..a640ff83 100644 --- a/enjoy/src/renderer/components/recordings/recording-detail.tsx +++ b/enjoy/src/renderer/components/recordings/recording-detail.tsx @@ -7,6 +7,7 @@ import { Separator, ScrollArea, toast } from "@renderer/components/ui"; import { useState, useContext, useEffect } from "react"; import { AppSettingsProviderContext } from "@renderer/context"; import { Tooltip } from "react-tooltip"; +import { usePronunciationAssessments } from "@renderer/hooks"; export const RecordingDetail = (props: { recording: RecordingType; @@ -25,7 +26,8 @@ export const RecordingDetail = (props: { }>(); const [isPlaying, setIsPlaying] = useState(false); - const { EnjoyApp, learningLanguage } = useContext(AppSettingsProviderContext); + const { learningLanguage } = useContext(AppSettingsProviderContext); + const { createAssessment } = usePronunciationAssessments(); const [assessing, setAssessing] = useState(false); const assess = () => { @@ -33,8 +35,11 @@ export const RecordingDetail = (props: { if (result) return; setAssessing(true); - EnjoyApp.recordings - .assess(recording.id, learningLanguage) + createAssessment({ + recording, + reference: recording.referenceText, + language: recording.language || learningLanguage, + }) .catch((err) => { toast.error(err.message); }) diff --git a/enjoy/src/renderer/context/index.ts b/enjoy/src/renderer/context/index.ts index e26ccf69..caf8dec3 100644 --- a/enjoy/src/renderer/context/index.ts +++ b/enjoy/src/renderer/context/index.ts @@ -5,4 +5,3 @@ export * from "./db-provider"; export * from './hotkeys-settings-provider' export * from "./media-player-provider"; export * from "./theme-provider"; -export * from "./wavesurfer-provider"; diff --git a/enjoy/src/renderer/context/wavesurfer-provider.tsx b/enjoy/src/renderer/context/wavesurfer-provider.tsx deleted file mode 100644 index a31393d1..00000000 --- a/enjoy/src/renderer/context/wavesurfer-provider.tsx +++ /dev/null @@ -1,185 +0,0 @@ -import { createContext, useEffect, useState, useContext } from "react"; -import { extractFrequencies } from "@/utils"; -import { AppSettingsProviderContext } from "@renderer/context"; -import WaveSurfer from "wavesurfer.js"; -import Regions, { - type Region as RegionType, -} from "wavesurfer.js/dist/plugins/regions"; - -type WavesurferContextType = { - media: AudioType | VideoType; - setMedia: (media: AudioType | VideoType) => void; - setMediaProvider: (mediaProvider: HTMLAudioElement | null) => void; - wavesurfer: WaveSurfer; - setRef: (ref: any) => void; - initialized: boolean; - currentTime: number; - currentSegmentIndex: number; - setCurrentSegmentIndex: (index: number) => void; - zoomRatio: number; -}; - -export const WavesurferContext = createContext(null); - -export const WavesurferProvider = ({ - children, -}: { - children: React.ReactNode; -}) => { - const { EnjoyApp } = useContext(AppSettingsProviderContext); - - const [media, setMedia] = useState(null); - const [mediaProvider, setMediaProvider] = useState( - null - ); - const [wavesurfer, setWavesurfer] = useState(null); - const [regions, setRegions] = useState(null); - const [ref, setRef] = useState(null); - - // Player state - const [initialized, setInitialized] = useState(false); - const [currentTime, setCurrentTime] = useState(0); - const [seek, setSeek] = useState<{ - seekTo: number; - timestamp: number; - }>(); - const [currentSegmentIndex, setCurrentSegmentIndex] = useState(0); - const [zoomRatio, setZoomRatio] = useState(1.0); - const [isPlaying, setIsPlaying] = useState(false); - const [playMode, setPlayMode] = useState<"loop" | "single" | "all">("all"); - const [playBackRate, setPlaybackRate] = useState(1); - const [displayInlineCaption, setDisplayInlineCaption] = - useState(true); - - const initializeWavesurfer = async () => { - if (!media) return; - if (!mediaProvider) return; - if (!ref.current) return; - - const waveform = await EnjoyApp.waveforms.find(media.md5); - const ws = WaveSurfer.create({ - container: ref.current, - height: 250, - waveColor: "#eee", - progressColor: "rgba(0, 0, 0, 0.15)", - cursorColor: "#aaa", - barWidth: 2, - autoScroll: true, - minPxPerSec: 150, - autoCenter: false, - dragToSeek: false, - media: mediaProvider, - peaks: waveform ? [waveform.peaks] : undefined, - duration: waveform ? waveform.duration : undefined, - }); - - const blob = await fetch(media.src).then((res) => res.blob()); - - if (waveform) { - ws.loadBlob(blob, [waveform.peaks], waveform.duration); - setInitialized(true); - } else { - ws.loadBlob(blob); - } - - // Set up region plugin - setRegions(ws.registerPlugin(Regions.create())); - - setWavesurfer(ws); - }; - - /* - * Initialize wavesurfer when container ref is available - * and mediaProvider is available - */ - useEffect(() => { - initializeWavesurfer(); - }, [media, ref, mediaProvider]); - - /* - * When wavesurfer is initialized, - * set up event listeners for wavesurfer - * and clean up when component is unmounted - */ - useEffect(() => { - if (!wavesurfer) return; - - setCurrentTime(0); - setIsPlaying(false); - - const subscriptions = [ - wavesurfer.on("play", () => setIsPlaying(true)), - wavesurfer.on("pause", () => setIsPlaying(false)), - wavesurfer.on("loading", (percent: number) => console.log(`${percent}%`)), - wavesurfer.on("timeupdate", (time: number) => setCurrentTime(time)), - wavesurfer.on("decode", () => { - const peaks: Float32Array = wavesurfer - .getDecodedData() - .getChannelData(0); - const duration: number = wavesurfer.getDuration(); - const sampleRate = wavesurfer.options.sampleRate; - const _frequencies = extractFrequencies({ peaks, sampleRate }); - const _waveform = { - peaks: Array.from(peaks), - duration, - sampleRate, - frequencies: _frequencies, - }; - EnjoyApp.waveforms.save(media.md5, _waveform); - }), - wavesurfer.on("ready", () => { - setInitialized(true); - }), - ]; - - return () => { - subscriptions.forEach((unsub) => unsub()); - }; - }, [wavesurfer]); - - /* - * When regions are available, - * set up event listeners for regions - * and clean up when component is unmounted - */ - useEffect(() => { - if (!regions) return; - - const subscriptions = [ - wavesurfer.on("finish", () => { - if (playMode !== "loop") return; - - regions?.getRegions()[0]?.play(); - }), - - regions.on("region-created", (region: RegionType) => { - region.on("click", () => { - wavesurfer.play(region.start, region.end); - }); - }), - ]; - - return () => { - subscriptions.forEach((unsub) => unsub()); - }; - }); - - return ( - - {children} - - ); -}; diff --git a/enjoy/src/renderer/hooks/use-pronunciation-assessments.tsx b/enjoy/src/renderer/hooks/use-pronunciation-assessments.tsx index 706c6efa..03171e78 100644 --- a/enjoy/src/renderer/hooks/use-pronunciation-assessments.tsx +++ b/enjoy/src/renderer/hooks/use-pronunciation-assessments.tsx @@ -1,9 +1,73 @@ import * as sdk from "microsoft-cognitiveservices-speech-sdk"; import { useContext } from "react"; import { AppSettingsProviderContext } from "@renderer/context"; +import camelcaseKeys from "camelcase-keys"; export const usePronunciationAssessments = () => { - const { webApi } = useContext(AppSettingsProviderContext); + const { webApi, EnjoyApp } = useContext(AppSettingsProviderContext); + + const createAssessment = async (params: { + language: string; + recording: RecordingType; + reference?: string; + targetId?: string; + targetType?: string; + }) => { + let { recording, targetId, targetType } = params; + if (targetId && targetType && !recording) { + recording = await EnjoyApp.recordings.findOne(targetId); + } + + await EnjoyApp.recordings.sync(recording.id); + const blob = await (await fetch(recording.src)).blob(); + targetId = recording.id; + targetType = "Recording"; + + const { language, reference = recording.referenceText } = params; + + const { + id: tokenId, + token, + region, + } = await webApi.generateSpeechToken({ + purpose: "pronunciation_assessment", + targetId, + targetType, + }); + + const result = await assess( + { + blob, + language, + reference, + }, + { token, region } + ); + + const resultJson = camelcaseKeys( + JSON.parse(JSON.stringify(result.detailResult)), + { + deep: true, + } + ); + resultJson.tokenId = tokenId; + resultJson.duration = recording?.duration; + + return EnjoyApp.pronunciationAssessments.create({ + targetId: recording.id, + targetType: "Recording", + pronunciationScore: result.pronunciationScore, + accuracyScore: result.accuracyScore, + completenessScore: result.completenessScore, + fluencyScore: result.fluencyScore, + prosodyScore: result.prosodyScore, + grammarScore: result.contentAssessmentResult?.grammarScore, + vocabularyScore: result.contentAssessmentResult?.vocabularyScore, + topicScore: result.contentAssessmentResult?.topicScore, + result: resultJson, + language: params.language || recording.language, + }); + }; const assess = async ( params: { @@ -11,13 +75,13 @@ export const usePronunciationAssessments = () => { language: string; reference?: string; }, - options?: { - targetId?: string; - targetType?: string; + options: { + token: string; + region: string; } - ) => { + ): Promise => { const { blob, language, reference } = params; - const { id, token, region } = await webApi.generateSpeechToken(options); + const { token, region } = options; const config = sdk.SpeechConfig.fromAuthorizationToken(token, region); const audioConfig = sdk.AudioConfig.fromWavFileInput( new File([blob], "audio.wav") @@ -74,6 +138,7 @@ export const usePronunciationAssessments = () => { }; return { + createAssessment, assess, }; }; diff --git a/enjoy/src/types/enjoy-app.d.ts b/enjoy/src/types/enjoy-app.d.ts index 75cb9345..0666e6c7 100644 --- a/enjoy/src/types/enjoy-app.d.ts +++ b/enjoy/src/types/enjoy-app.d.ts @@ -178,7 +178,6 @@ type EnjoyAppType = { update: (id: string, params: any) => Promise; destroy: (id: string) => Promise; upload: (id: string) => Promise; - assess: (id: string, language?: string) => Promise; stats: (params: { from: string; to: string }) => Promise<{ count: number; duration: number; diff --git a/enjoy/src/types/recording.d.ts b/enjoy/src/types/recording.d.ts index 3894c7eb..a3a7b245 100644 --- a/enjoy/src/types/recording.d.ts +++ b/enjoy/src/types/recording.d.ts @@ -4,6 +4,7 @@ type RecordingType = { target?: AudioType | (MessageType & any); targetId: string; targetType: string; + language?: string; pronunciationAssessment?: PronunciationAssessmentType & any; referenceId: number; referenceText?: string;