Refactor pronunciation assess (#820)

* refactor pronunciation assessment request

* remove unused code
an-lee committed 2024-07-17 17:25:42 +08:00 (committed by GitHub)
parent ff4fcad32f
commit 93d9e190e2
10 changed files with 141 additions and 305 deletions
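
At a high level, this change moves the assessment request out of the main process and into the renderer: the Recording#assess model method, the recordings-assess IPC channel, and the main-process AzureSpeechSdk call are removed, and the new usePronunciationAssessments hook drives the Azure Speech SDK directly before persisting the scores over IPC. A rough before/after sketch of a call site, assuming a React component where recording and learningLanguage are already in scope (only the hook, its parameters, and the EnjoyApp APIs appear in the diff; the surrounding statements are illustrative):

// Before: ask the main process to assess a recording over IPC.
await EnjoyApp.recordings.assess(recording.id, learningLanguage);

// After: run the assessment in the renderer via the new hook, which syncs the
// recording, requests a speech token, calls the Azure Speech SDK, and persists
// the result through EnjoyApp.pronunciationAssessments.create.
const { createAssessment } = usePronunciationAssessments();
await createAssessment({
  recording,
  language: recording.language || learningLanguage,
  reference: recording.referenceText,
});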

View File

@@ -47,27 +47,22 @@ class PronunciationAssessmentsHandler {
private async create(
_event: IpcMainEvent,
data: Partial<Attributes<PronunciationAssessment>> & {
blob: {
type: string;
arrayBuffer: ArrayBuffer;
};
}
data: Partial<Attributes<PronunciationAssessment>>
) {
const recording = await Recording.createFromBlob(data.blob, {
targetId: "00000000-0000-0000-0000-000000000000",
targetType: "None",
referenceText: data.referenceText,
language: data.language,
const { targetId, targetType } = data;
const existed = await PronunciationAssessment.findOne({
where: {
targetId,
targetType,
},
});
try {
const assessment = await recording.assess(data.language);
return assessment.toJSON();
} catch (error) {
await recording.destroy();
throw error;
if (existed) {
return existed.toJSON();
}
const assessment = await PronunciationAssessment.create(data);
return assessment.toJSON();
}
private async update(

View File

@@ -24,9 +24,7 @@ import log from "@main/logger";
import storage from "@main/storage";
import { Client } from "@/api";
import { WEB_API_URL } from "@/constants";
import { AzureSpeechSdk } from "@main/azure-speech-sdk";
import echogarden from "@main/echogarden";
import camelcaseKeys from "camelcase-keys";
import { t } from "i18next";
import { Attributes } from "sequelize";
import { v5 as uuidv5 } from "uuid";
@@ -156,69 +154,6 @@ export class Recording extends Model<Recording> {
});
}
async assess(language?: string) {
const assessment = await PronunciationAssessment.findOne({
where: { targetId: this.id, targetType: "Recording" },
});
if (assessment) {
return assessment;
}
await this.sync();
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger,
});
const {
id: tokenId,
token,
region,
} = await webApi.generateSpeechToken({
targetId: this.id,
targetType: "Recording",
});
const sdk = new AzureSpeechSdk(token, region);
const result = await sdk.pronunciationAssessment({
filePath: this.filePath,
reference: this.referenceText,
language: language || this.language,
});
const resultJson = camelcaseKeys(
JSON.parse(JSON.stringify(result.detailResult)),
{
deep: true,
}
);
resultJson.duration = this.duration;
resultJson.tokenId = tokenId;
const _pronunciationAssessment = await PronunciationAssessment.create(
{
targetId: this.id,
targetType: "Recording",
pronunciationScore: result.pronunciationScore,
accuracyScore: result.accuracyScore,
completenessScore: result.completenessScore,
fluencyScore: result.fluencyScore,
prosodyScore: result.prosodyScore,
grammarScore: result.contentAssessmentResult?.grammarScore,
vocabularyScore: result.contentAssessmentResult?.vocabularyScore,
topicScore: result.contentAssessmentResult?.topicScore,
result: resultJson,
language: language || this.language,
},
{
include: Recording,
}
);
return _pronunciationAssessment;
}
@AfterFind
static async findTarget(findResult: Recording | Recording[]) {
if (!findResult) return;

View File

@@ -327,9 +327,6 @@ contextBridge.exposeInMainWorld("__ENJOY_APP__", {
upload: (id: string) => {
return ipcRenderer.invoke("recordings-upload", id);
},
assess: (id: string, language?: string) => {
return ipcRenderer.invoke("recordings-assess", id, language);
},
stats: (params: { from: string; to: string }) => {
return ipcRenderer.invoke("recordings-stats", params);
},

View File

@@ -29,10 +29,11 @@ import { zodResolver } from "@hookform/resolvers/zod";
import { LoaderIcon, MicIcon, SquareIcon } from "lucide-react";
import WaveSurfer from "wavesurfer.js";
import RecordPlugin from "wavesurfer.js/dist/plugins/record";
import { usePronunciationAssessments } from "@/renderer/hooks";
const pronunciationAssessmentSchema = z.object({
file: z.instanceof(FileList).optional(),
recording: z.instanceof(Blob).optional(),
recordingFile: z.instanceof(Blob).optional(),
language: z.string().min(2),
referenceText: z.string().optional(),
});
@@ -41,6 +42,7 @@ export const PronunciationAssessmentForm = () => {
const navigate = useNavigate();
const { EnjoyApp, learningLanguage } = useContext(AppSettingsProviderContext);
const [submitting, setSubmitting] = useState(false);
const { createAssessment } = usePronunciationAssessments();
const form = useForm<z.infer<typeof pronunciationAssessmentSchema>>({
resolver: zodResolver(pronunciationAssessmentSchema),
@@ -55,34 +57,28 @@ export const PronunciationAssessmentForm = () => {
const onSubmit = async (
data: z.infer<typeof pronunciationAssessmentSchema>
) => {
console.log(data);
if ((!data.file || data.file.length === 0) && !data.recording) {
if ((!data.file || data.file.length === 0) && !data.recordingFile) {
toast.error(t("noFileOrRecording"));
form.setError("recording", { message: t("noFileOrRecording") });
form.setError("recordingFile", { message: t("noFileOrRecording") });
return;
}
const { language, referenceText, file, recording } = data;
let arrayBuffer: ArrayBuffer;
if (recording) {
arrayBuffer = await recording.arrayBuffer();
} else {
arrayBuffer = await new Blob([file[0]]).arrayBuffer();
}
const { language, referenceText } = data;
const recording = await createRecording(data);
setSubmitting(true);
toast.promise(
EnjoyApp.pronunciationAssessments
.create({
language,
referenceText,
blob: {
type: recording?.type || file[0].type,
arrayBuffer,
},
})
createAssessment({
language,
reference: referenceText,
recording,
})
.then(() => {
navigate("/pronunciation_assessments");
})
.catch(() => {
EnjoyApp.recordings.destroy(recording.id);
})
.finally(() => setSubmitting(false)),
{
loading: t("assessing"),
@@ -92,6 +88,35 @@ export const PronunciationAssessmentForm = () => {
);
};
const createRecording = async (
data: z.infer<typeof pronunciationAssessmentSchema>
): Promise<RecordingType> => {
const { language, referenceText, file, recordingFile } = data;
let arrayBuffer: ArrayBuffer;
if (recordingFile) {
arrayBuffer = await recordingFile.arrayBuffer();
} else {
arrayBuffer = await new Blob([file[0]]).arrayBuffer();
}
const recording = await EnjoyApp.recordings.create({
language,
referenceText,
blob: {
type: recordingFile?.type || file[0].type,
arrayBuffer,
},
});
try {
await EnjoyApp.recordings.sync(recording.id);
return recording;
} catch (err) {
toast.error(err.message);
EnjoyApp.recordings.destroy(recording.id);
return;
}
};
return (
<div className="max-w-screen-md mx-auto">
<Form {...form}>
@@ -128,7 +153,7 @@ export const PronunciationAssessmentForm = () => {
<div className="grid gap-4 border p-4 rounded-lg">
<FormField
control={form.control}
name="recording"
name="recordingFile"
render={({ field }) => (
<FormItem className="grid w-full items-center gap-1.5">
<Input
@@ -140,7 +165,7 @@ export const PronunciationAssessmentForm = () => {
/>
<RecorderButton
onStart={() => {
form.resetField("recording");
form.resetField("recordingFile");
}}
onFinish={(blob) => {
field.onChange(blob);
@@ -149,11 +174,11 @@ export const PronunciationAssessmentForm = () => {
</FormItem>
)}
/>
{form.watch("recording") && (
{form.watch("recordingFile") && (
<div className="">
<audio controls className="w-full">
<source
src={URL.createObjectURL(form.watch("recording"))}
src={URL.createObjectURL(form.watch("recordingFile"))}
/>
</audio>
</div>

View File

@@ -7,6 +7,7 @@ import { Separator, ScrollArea, toast } from "@renderer/components/ui";
import { useState, useContext, useEffect } from "react";
import { AppSettingsProviderContext } from "@renderer/context";
import { Tooltip } from "react-tooltip";
import { usePronunciationAssessments } from "@renderer/hooks";
export const RecordingDetail = (props: {
recording: RecordingType;
@@ -25,7 +26,8 @@ export const RecordingDetail = (props: {
}>();
const [isPlaying, setIsPlaying] = useState(false);
const { EnjoyApp, learningLanguage } = useContext(AppSettingsProviderContext);
const { learningLanguage } = useContext(AppSettingsProviderContext);
const { createAssessment } = usePronunciationAssessments();
const [assessing, setAssessing] = useState(false);
const assess = () => {
@@ -33,8 +35,11 @@ export const RecordingDetail = (props: {
if (result) return;
setAssessing(true);
EnjoyApp.recordings
.assess(recording.id, learningLanguage)
createAssessment({
recording,
reference: recording.referenceText,
language: recording.language || learningLanguage,
})
.catch((err) => {
toast.error(err.message);
})

View File

@@ -5,4 +5,3 @@ export * from "./db-provider";
export * from './hotkeys-settings-provider'
export * from "./media-player-provider";
export * from "./theme-provider";
export * from "./wavesurfer-provider";

View File

@@ -1,185 +0,0 @@
import { createContext, useEffect, useState, useContext } from "react";
import { extractFrequencies } from "@/utils";
import { AppSettingsProviderContext } from "@renderer/context";
import WaveSurfer from "wavesurfer.js";
import Regions, {
type Region as RegionType,
} from "wavesurfer.js/dist/plugins/regions";
type WavesurferContextType = {
media: AudioType | VideoType;
setMedia: (media: AudioType | VideoType) => void;
setMediaProvider: (mediaProvider: HTMLAudioElement | null) => void;
wavesurfer: WaveSurfer;
setRef: (ref: any) => void;
initialized: boolean;
currentTime: number;
currentSegmentIndex: number;
setCurrentSegmentIndex: (index: number) => void;
zoomRatio: number;
};
export const WavesurferContext = createContext<WavesurferContextType>(null);
export const WavesurferProvider = ({
children,
}: {
children: React.ReactNode;
}) => {
const { EnjoyApp } = useContext(AppSettingsProviderContext);
const [media, setMedia] = useState<AudioType | VideoType>(null);
const [mediaProvider, setMediaProvider] = useState<HTMLAudioElement | null>(
null
);
const [wavesurfer, setWavesurfer] = useState(null);
const [regions, setRegions] = useState<Regions | null>(null);
const [ref, setRef] = useState(null);
// Player state
const [initialized, setInitialized] = useState<boolean>(false);
const [currentTime, setCurrentTime] = useState<number>(0);
const [seek, setSeek] = useState<{
seekTo: number;
timestamp: number;
}>();
const [currentSegmentIndex, setCurrentSegmentIndex] = useState<number>(0);
const [zoomRatio, setZoomRatio] = useState<number>(1.0);
const [isPlaying, setIsPlaying] = useState(false);
const [playMode, setPlayMode] = useState<"loop" | "single" | "all">("all");
const [playBackRate, setPlaybackRate] = useState<number>(1);
const [displayInlineCaption, setDisplayInlineCaption] =
useState<boolean>(true);
const initializeWavesurfer = async () => {
if (!media) return;
if (!mediaProvider) return;
if (!ref.current) return;
const waveform = await EnjoyApp.waveforms.find(media.md5);
const ws = WaveSurfer.create({
container: ref.current,
height: 250,
waveColor: "#eee",
progressColor: "rgba(0, 0, 0, 0.15)",
cursorColor: "#aaa",
barWidth: 2,
autoScroll: true,
minPxPerSec: 150,
autoCenter: false,
dragToSeek: false,
media: mediaProvider,
peaks: waveform ? [waveform.peaks] : undefined,
duration: waveform ? waveform.duration : undefined,
});
const blob = await fetch(media.src).then((res) => res.blob());
if (waveform) {
ws.loadBlob(blob, [waveform.peaks], waveform.duration);
setInitialized(true);
} else {
ws.loadBlob(blob);
}
// Set up region plugin
setRegions(ws.registerPlugin(Regions.create()));
setWavesurfer(ws);
};
/*
* Initialize wavesurfer when container ref is available
* and mediaProvider is available
*/
useEffect(() => {
initializeWavesurfer();
}, [media, ref, mediaProvider]);
/*
* When wavesurfer is initialized,
* set up event listeners for wavesurfer
* and clean up when component is unmounted
*/
useEffect(() => {
if (!wavesurfer) return;
setCurrentTime(0);
setIsPlaying(false);
const subscriptions = [
wavesurfer.on("play", () => setIsPlaying(true)),
wavesurfer.on("pause", () => setIsPlaying(false)),
wavesurfer.on("loading", (percent: number) => console.log(`${percent}%`)),
wavesurfer.on("timeupdate", (time: number) => setCurrentTime(time)),
wavesurfer.on("decode", () => {
const peaks: Float32Array = wavesurfer
.getDecodedData()
.getChannelData(0);
const duration: number = wavesurfer.getDuration();
const sampleRate = wavesurfer.options.sampleRate;
const _frequencies = extractFrequencies({ peaks, sampleRate });
const _waveform = {
peaks: Array.from(peaks),
duration,
sampleRate,
frequencies: _frequencies,
};
EnjoyApp.waveforms.save(media.md5, _waveform);
}),
wavesurfer.on("ready", () => {
setInitialized(true);
}),
];
return () => {
subscriptions.forEach((unsub) => unsub());
};
}, [wavesurfer]);
/*
* When regions are available,
* set up event listeners for regions
* and clean up when component is unmounted
*/
useEffect(() => {
if (!regions) return;
const subscriptions = [
wavesurfer.on("finish", () => {
if (playMode !== "loop") return;
regions?.getRegions()[0]?.play();
}),
regions.on("region-created", (region: RegionType) => {
region.on("click", () => {
wavesurfer.play(region.start, region.end);
});
}),
];
return () => {
subscriptions.forEach((unsub) => unsub());
};
});
return (
<WavesurferContext.Provider
value={{
media,
setMedia,
setMediaProvider,
wavesurfer,
setRef,
initialized,
currentTime,
currentSegmentIndex,
setCurrentSegmentIndex,
zoomRatio,
}}
>
{children}
</WavesurferContext.Provider>
);
};

View File

@@ -1,9 +1,73 @@
import * as sdk from "microsoft-cognitiveservices-speech-sdk";
import { useContext } from "react";
import { AppSettingsProviderContext } from "@renderer/context";
import camelcaseKeys from "camelcase-keys";
export const usePronunciationAssessments = () => {
const { webApi } = useContext(AppSettingsProviderContext);
const { webApi, EnjoyApp } = useContext(AppSettingsProviderContext);
const createAssessment = async (params: {
language: string;
recording: RecordingType;
reference?: string;
targetId?: string;
targetType?: string;
}) => {
let { recording, targetId, targetType } = params;
if (targetId && targetType && !recording) {
recording = await EnjoyApp.recordings.findOne(targetId);
}
await EnjoyApp.recordings.sync(recording.id);
const blob = await (await fetch(recording.src)).blob();
targetId = recording.id;
targetType = "Recording";
const { language, reference = recording.referenceText } = params;
const {
id: tokenId,
token,
region,
} = await webApi.generateSpeechToken({
purpose: "pronunciation_assessment",
targetId,
targetType,
});
const result = await assess(
{
blob,
language,
reference,
},
{ token, region }
);
const resultJson = camelcaseKeys(
JSON.parse(JSON.stringify(result.detailResult)),
{
deep: true,
}
);
resultJson.tokenId = tokenId;
resultJson.duration = recording?.duration;
return EnjoyApp.pronunciationAssessments.create({
targetId: recording.id,
targetType: "Recording",
pronunciationScore: result.pronunciationScore,
accuracyScore: result.accuracyScore,
completenessScore: result.completenessScore,
fluencyScore: result.fluencyScore,
prosodyScore: result.prosodyScore,
grammarScore: result.contentAssessmentResult?.grammarScore,
vocabularyScore: result.contentAssessmentResult?.vocabularyScore,
topicScore: result.contentAssessmentResult?.topicScore,
result: resultJson,
language: params.language || recording.language,
});
};
const assess = async (
params: {
@@ -11,13 +75,13 @@ export const usePronunciationAssessments = () => {
language: string;
reference?: string;
},
options?: {
targetId?: string;
targetType?: string;
options: {
token: string;
region: string;
}
) => {
): Promise<sdk.PronunciationAssessmentResult> => {
const { blob, language, reference } = params;
const { id, token, region } = await webApi.generateSpeechToken(options);
const { token, region } = options;
const config = sdk.SpeechConfig.fromAuthorizationToken(token, region);
const audioConfig = sdk.AudioConfig.fromWavFileInput(
new File([blob], "audio.wav")
@@ -74,6 +138,7 @@ export const usePronunciationAssessments = () => {
};
return {
createAssessment,
assess,
};
};
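
Worth noting in the hook above: assess is now a pure Speech SDK wrapper that receives an already-issued { token, region } rather than calling webApi.generateSpeechToken itself, which lets createAssessment keep the token id and attach it to the persisted result (resultJson.tokenId). A minimal sketch of calling it directly, assuming webApi, a recording, and a WAV blob are available in a component (everything outside the assess and generateSpeechToken calls is illustrative; createAssessment remains the normal entry point):

const { assess } = usePronunciationAssessments();

const runAssessment = async (blob: Blob) => {
  // Token issuance stays with the caller so the token id can be recorded
  // alongside the assessment result.
  const { id: tokenId, token, region } = await webApi.generateSpeechToken({
    purpose: "pronunciation_assessment",
    targetId: recording.id,
    targetType: "Recording",
  });
  const result = await assess(
    { blob, language: recording.language, reference: recording.referenceText },
    { token, region }
  );
  console.log(tokenId, result.pronunciationScore);
};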

View File

@@ -178,7 +178,6 @@ type EnjoyAppType = {
update: (id: string, params: any) => Promise<RecordingType | undefined>;
destroy: (id: string) => Promise<void>;
upload: (id: string) => Promise<void>;
assess: (id: string, language?: string) => Promise<void>;
stats: (params: { from: string; to: string }) => Promise<{
count: number;
duration: number;

View File

@@ -4,6 +4,7 @@ type RecordingType = {
target?: AudioType | (MessageType & any);
targetId: string;
targetType: string;
language?: string;
pronunciationAssessment?: PronunciationAssessmentType & any;
referenceId: number;
referenceText?: string;