pr/fixes_plus_loading #6 by nbpe97 - opened
whisper-speaker-diarization/package.json
CHANGED
@@ -10,11 +10,13 @@
     "preview": "vite preview"
   },
   "dependencies": {
-    "@…
+    "@huggingface/transformers": "^3.3.1",
+    "prop-types": "^15.8.1",
     "react": "^18.3.1",
     "react-dom": "^18.3.1"
   },
   "devDependencies": {
+    "@rollup/plugin-commonjs": "^28.0.1",
     "@types/react": "^18.3.3",
     "@types/react-dom": "^18.3.0",
     "@vitejs/plugin-react": "^4.3.1",
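This PR adds `@rollup/plugin-commonjs` as a devDependency; how it is wired into the build is not shown in the diff. Below is a minimal, assumed `vite.config.js` sketch (not part of this PR) showing the usual way such a plugin is registered alongside `@vitejs/plugin-react`.

```js
// Hypothetical vite.config.js (not included in this PR) showing one common way
// to register @rollup/plugin-commonjs for the production build.
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
import commonjs from '@rollup/plugin-commonjs';

export default defineConfig({
  plugins: [react()],
  build: {
    rollupOptions: {
      // Convert CommonJS dependencies to ESM during the Rollup build step.
      plugins: [commonjs()],
    },
  },
});
```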
whisper-speaker-diarization/src/App.jsx
CHANGED
@@ -1,218 +1,257 @@
- … (previous 218-line implementation, removed by this PR, not captured in this view)
+import { useEffect, useState, useRef, useCallback } from 'react';
+
+import Progress from './components/Progress';
+import MediaInput from './components/MediaInput';
+import Transcript from './components/Transcript';
+import LanguageSelector from './components/LanguageSelector';
+
+
+async function hasWebGPU() {
+  if (!navigator.gpu) {
+    return false;
+  }
+  try {
+    const adapter = await navigator.gpu.requestAdapter();
+    return !!adapter;
+  } catch (e) {
+    return false;
+  }
+}
+
+function App() {
+
+  // Create a reference to the worker object.
+  const worker = useRef(null);
+
+  // Model loading and progress
+  const [status, setStatus] = useState(null);
+  const [loadingMessage, setLoadingMessage] = useState('');
+  const [progressItems, setProgressItems] = useState([]);
+
+  const mediaInputRef = useRef(null);
+  const [audio, setAudio] = useState(null);
+  const [language, setLanguage] = useState('en');
+
+  const [result, setResult] = useState(null);
+  const [time, setTime] = useState(null);
+  const [audioLength, setAudioLength] = useState(null);
+  const [currentTime, setCurrentTime] = useState(0);
+
+  const [device, setDevice] = useState('webgpu'); // Try use WebGPU first
+  const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
+  useEffect(() => {
+    hasWebGPU().then((b) => {
+      setModelSize(b ? 196 : 77);
+      setDevice(b ? 'webgpu' : 'wasm');
+    });
+  }, []);
+
+  // Create a callback function for messages from the worker thread.
+  const onMessageReceived = (e) => {
+    switch (e.data.status) {
+      case 'loading':
+        // Model file start load: add a new progress item to the list.
+        setStatus('loading');
+        setLoadingMessage(e.data.data);
+        break;
+
+      case 'initiate':
+        setProgressItems(prev => [...prev, e.data]);
+        break;
+
+      case 'progress':
+        // Model file progress: update one of the progress items.
+        setProgressItems(
+          prev => prev.map(item => {
+            if (item.file === e.data.file) {
+              return { ...item, ...e.data }
+            }
+            return item;
+          })
+        );
+        break;
+
+      case 'done':
+        // Model file loaded: remove the progress item from the list.
+        setProgressItems(
+          prev => prev.filter(item => item.file !== e.data.file)
+        );
+        break;
+
+      case 'loaded':
+        // Pipeline ready: the worker is ready to accept messages.
+        setStatus('ready');
+        break;
+
+      case 'transcribe-progress': {
+        // Update progress for transcription/diarization
+        const { task, progress, total } = e.data.data;
+        setProgressItems(prev => {
+          const existingIndex = prev.findIndex(item => item.file === task);
+          if (existingIndex >= 0) {
+            return prev.map((item, i) =>
+              i === existingIndex ? { ...item, progress, total } : item
+            );
+          }
+          const newItem = { file: task, progress, total };
+          return [...prev, newItem];
+        });
+        break;
+      }
+
+      case 'complete':
+        setResult(e.data.result);
+        setTime(e.data.time);
+        setAudioLength(e.data.audio_length);
+        setStatus('ready');
+        break;
+    }
+  };
+
+  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
+  useEffect(() => {
+    if (!worker.current) {
+      // Create the worker if it does not yet exist.
+      worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+        type: 'module'
+      });
+    }
+
+    // Attach the callback function as an event listener.
+    worker.current.addEventListener('message', onMessageReceived);
+
+    // Define a cleanup function for when the component is unmounted.
+    return () => {
+      worker.current.removeEventListener('message', onMessageReceived);
+    };
+  }, []);
+
+  const handleClick = useCallback(() => {
+    setResult(null);
+    setTime(null);
+    if (status === null) {
+      setStatus('loading');
+      worker.current.postMessage({ type: 'load', data: { device } });
+    } else {
+      setStatus('running');
+      worker.current.postMessage({
+        type: 'run', data: { audio, language }
+      });
+    }
+  }, [status, audio, language, device]);
+
+  return (
+    <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
+
+      {(status === 'loading' || status === 'running') && (
+        <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
+          <div className="w-[500px]">
+            <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
+            {progressItems
+              .sort((a, b) => {
+                // Define the order: transcription -> segmentation -> diarization
+                const order = { 'transcription': 0, 'segmentation': 1, 'diarization': 2 };
+                return (order[a.file] ?? 3) - (order[b.file] ?? 3);
+              })
+              .map(({ file, progress, total }, i) => (
+                <Progress
+                  key={i}
+                  text={file === 'transcription' ? 'Converting speech to text' :
+                    file === 'segmentation' ? 'Detecting word timestamps' :
+                      file === 'diarization' ? 'Identifying speakers' :
+                        file}
+                  percentage={progress}
+                  total={total}
+                />
+              ))
+            }
+          </div>
+        </div>
+      )}
+      <div className="my-auto">
+        <div className="flex flex-col items-center mb-2 text-center">
+          <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
+          <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
+        </div>
+
+        <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
+          {
+            !audio && (
+              <p className="mb-2">
+                You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
+                two powerful speech recognition models for generating word-level timestamps across 100 different languages and speaker segmentation, respectively.
+                Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
+                <br />
+                Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗 Transformers.js</a> and ONNX Runtime Web,
+                meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
+              </p>
+            )
+          }
+
+          <div className="flex flex-col w-full m-3 max-w-[520px]">
+            <span className="text-sm mb-0.5">Input audio/video</span>
+            <MediaInput
+              ref={mediaInputRef}
+              className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
+              onInputChange={(audio) => {
+                setResult(null);
+                setAudio(audio);
+              }}
+              onTimeUpdate={(time) => setCurrentTime(time)}
+              onMessage={onMessageReceived}
+            />
+          </div>
+
+          <div className="relative w-full flex justify-center items-center">
+            <button
+              className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
+              onClick={handleClick}
+              disabled={status === 'running' || (status !== null && audio === null)}
+            >
+              {status === null ? 'Load model' :
+                status === 'running'
+                  ? 'Running...'
+                  : 'Run model'
+              }
+            </button>
+
+            {status !== null &&
+              <div className='absolute right-0 bottom-0'>
+                <span className="text-xs">Language:</span>
+                <br />
+                <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
+              </div>
+            }
+          </div>
+
+          {
+            result && time && (
+              <>
+                <div className="w-full mt-4 border rounded-md">
+                  <Transcript
+                    className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
+                    transcript={result.transcript}
+                    segments={result.segments}
+                    currentTime={currentTime}
+                    setCurrentTime={(time) => {
+                      setCurrentTime(time);
+                      mediaInputRef.current.setMediaTime(time);
+                    }}
+                  />
+                </div>
+                <p className="text-sm text-end p-1">Generation time:
+                  <span className="font-semibold">{(time / 1000).toLocaleString()} s</span>
+                </p>
+                <p className="text-sm text-end p-1">
+                  <span className="font-semibold">{(audioLength / (time / 1000)).toFixed(2)}x transcription!</span>
+                </p>
+              </>
+            )
+          }
+        </div>
+      </div>
+    </div >
+  )
+}
+
+export default App
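The `onMessageReceived` handler above is driven by plain `{ data: { status, ... } }` objects, so it can consume both real worker events and the synthetic ones MediaInput now emits via its `onMessage` prop. A minimal sketch of the message shapes it handles; the status strings and field names come from the code above, the concrete values are illustrative only.

```js
// Message shapes consumed by onMessageReceived (values are made-up examples).
const exampleMessages = [
  { data: { status: 'loading', data: 'Loading models (webgpu)...' } },               // sets the overlay text
  { data: { status: 'initiate', name: 'audio-decoder', file: 'audio-buffer' } },     // adds a progress item
  { data: { status: 'progress', file: 'audio-buffer', progress: 42, total: 1000 } }, // updates that item
  { data: { status: 'done', file: 'audio-buffer' } },                                // removes it
  { data: { status: 'loaded' } },                                                    // flips status to 'ready'
  { data: { status: 'transcribe-progress', data: { task: 'transcription', progress: 50, total: 160000 } } },
  { data: { status: 'complete', result: { transcript: {}, segments: [] }, time: 1234, audio_length: 10 } },
];
```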
whisper-speaker-diarization/src/components/MediaInput.jsx
CHANGED
@@ -1,8 +1,8 @@
 import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
-
+import PropTypes from 'prop-types';
 const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
 
-const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) => {
+const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, onMessage, ...props }, ref) => {
   // UI states
   const [dragging, setDragging] = useState(false);
   const fileInputRef = useRef(null);
@@ -89,7 +89,40 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
     const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
 
     try {
+      // Start audio decoding
+      onMessage({
+        data: {
+          status: 'loading',
+          data: 'Decoding audio buffer...'
+        }
+      });
+
+      onMessage({
+        data: {
+          status: 'initiate',
+          name: 'audio-decoder',
+          file: 'audio-buffer'
+        }
+      });
+
       const audioBuffer = await audioContext.decodeAudioData(buffer);
+
+      // Audio decoding complete
+      onMessage({
+        data: {
+          status: 'done',
+          name: 'audio-decoder',
+          file: 'audio-buffer'
+        }
+      });
+
+      // Audio decoding complete
+      onMessage({
+        data: {
+          status: 'loaded'
+        }
+      });
+
       let audio;
       if (audioBuffer.numberOfChannels === 2) {
         // Merge channels
@@ -145,8 +178,8 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
       onClick={handleClick}
      onDragOver={handleDragOver}
      onDrop={handleDrop}
-      onDragEnter={(…
-      onDragLeave={(…
+      onDragEnter={() => setDragging(true)}
+      onDragLeave={() => setDragging(false)}
    >
      <input
        type="file"
@@ -189,6 +222,13 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
    </div>
  );
 });
+
+MediaInput.propTypes = {
+  onInputChange: PropTypes.func.isRequired,
+  onTimeUpdate: PropTypes.func.isRequired,
+  onMessage: PropTypes.func.isRequired
+};
+
 MediaInput.displayName = 'MediaInput';
 
 export default MediaInput;
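The `// Merge channels` context line above refers to the stereo-to-mono downmix applied before the decoded audio is handed to the worker; the implementation itself lies outside the changed hunks. The sketch below is an assumed version of that downmix (the helper name `mergeToMono` is hypothetical) using the standard Web Audio `getChannelData` API.

```js
// Assumed sketch of the stereo downmix hinted at by the "// Merge channels" context line.
function mergeToMono(audioBuffer) {
  const SCALING_FACTOR = Math.sqrt(2); // rough loudness compensation for the downmix
  const left = audioBuffer.getChannelData(0);
  const right = audioBuffer.getChannelData(1);

  const mono = new Float32Array(left.length);
  for (let i = 0; i < left.length; ++i) {
    // Average the two channels into a single mono signal for the 16 kHz pipeline.
    mono[i] = (SCALING_FACTOR * (left[i] + right[i])) / 2;
  }
  return mono;
}
```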
whisper-speaker-diarization/src/components/Progress.jsx
CHANGED
@@ -1,3 +1,5 @@
+import PropTypes from 'prop-types';
+
 function formatBytes(size) {
   const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
   return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
@@ -13,3 +15,9 @@ export default function Progress({ text, percentage, total }) {
     </div>
   );
 }
+
+Progress.propTypes = {
+  text: PropTypes.string.isRequired,
+  percentage: PropTypes.number,
+  total: PropTypes.number
+};
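For reference, the `formatBytes` helper shown as context above picks its unit from powers of 1024, so the `total` prop (now typed as a number via PropTypes) is expected in bytes. A few worked values:

```js
// Worked examples for the formatBytes helper shown above.
formatBytes(0);                  // '0B'   (size == 0 short-circuits the log)
formatBytes(77 * 1024 * 1024);   // '77MB' (i = 2, 80740352 / 1024**2 = 77)
formatBytes(196 * 1024 * 1024);  // '196MB'
```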
whisper-speaker-diarization/src/worker.js
CHANGED
@@ -1,124 +1,272 @@
- … (previous 124-line implementation, removed by this PR, not captured in this view)
+
+import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@huggingface/transformers';
+
+const PER_DEVICE_CONFIG = {
+    webgpu: {
+        dtype: {
+            encoder_model: 'fp32',
+            decoder_model_merged: 'q4',
+        },
+        device: 'webgpu',
+    },
+    wasm: {
+        dtype: 'q8',
+        device: 'wasm',
+    },
+};
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class PipelineSingeton {
+    static asr_model_id = 'onnx-community/whisper-base_timestamped';
+    static asr_instance = null;
+
+    static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
+    static segmentation_instance = null;
+    static segmentation_processor = null;
+
+    static async getInstance(progress_callback = null, device = 'webgpu') {
+        this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
+            ...PER_DEVICE_CONFIG[device],
+            progress_callback,
+        });
+
+        this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
+            progress_callback,
+        });
+        this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
+            // NOTE: WebGPU is not currently supported for this model
+            // See https://github.com/microsoft/onnxruntime/issues/21386
+            device: 'wasm',
+            dtype: 'fp32',
+            progress_callback,
+        });
+
+        return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
+    }
+}
+
+async function load({ device }) {
+    try {
+        const message = {
+            status: 'loading',
+            data: `Loading models (${device})...`
+        };
+        self.postMessage(message);
+
+        const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
+            // We also add a progress callback to the pipeline so that we can
+            // track model loading.
+            self.postMessage(x);
+        }, device);
+
+        if (device === 'webgpu') {
+            const warmupMessage = {
+                status: 'loading',
+                data: 'Compiling shaders and warming up model...'
+            };
+
+            self.postMessage(warmupMessage);
+
+            await transcriber(new Float32Array(16_000), {
+                language: 'en',
+            });
+        }
+
+        self.postMessage({ status: 'loaded' });
+    } catch (error) {
+        console.error('Loading error:', error);
+        const errorMessage = {
+            status: 'error',
+            error: error.message || 'Failed to load models'
+        };
+        self.postMessage(errorMessage);
+    }
+}
+
+async function segment(processor, model, audio) {
+    try {
+        // Report start of segmentation
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'segmentation',
+                progress: 0,
+                total: audio.length
+            }
+        });
+
+        // Process audio in chunks to show progress
+        const inputs = await processor(audio);
+
+        // Report segmentation feature extraction progress
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'segmentation',
+                progress: 50,
+                total: audio.length
+            }
+        });
+
+        const { logits } = await model(inputs);
+
+        // Report segmentation completion
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'segmentation',
+                progress: 100,
+                total: audio.length
+            }
+        });
+
+        // Start diarization
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'diarization',
+                progress: 0,
+                total: audio.length
+            }
+        });
+
+        const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
+
+        // Attach labels and report diarization completion
+        for (const segment of segments) {
+            segment.label = model.config.id2label[segment.id];
+        }
+
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'diarization',
+                progress: 100,
+                total: audio.length
+            }
+        });
+
+        return segments;
+    } catch (error) {
+        console.error('Segmentation error:', error);
+        return [{
+            id: 0,
+            start: 0,
+            end: (audio.length / 480016) * 30,
+            label: 'SPEAKER_00',
+            confidence: 1.0
+        }];
+    }
+}
+
+async function run({ audio, language }) {
+    try {
+        const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
+
+        const audioLengthSeconds = (audio.length / 16000);
+
+        // Initialize transcription progress
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'transcription',
+                progress: 0,
+                total: audio.length
+            }
+        });
+
+        const start = performance.now();
+        // Process in 30-second chunks
+        const CHUNK_SIZE = 3 * 30 * 16000; // 30 seconds * 16000 samples/second
+        const numChunks = Math.ceil(audio.length / CHUNK_SIZE);
+        let transcriptResults = [];
+
+        for (let i = 0; i < numChunks; i++) {
+            const start = i * CHUNK_SIZE;
+            const end = Math.min((i + 1) * CHUNK_SIZE, audio.length);
+            const chunk = audio.slice(start, end);
+
+            // Process chunk
+            const chunkResult = await transcriber(chunk, {
+                language,
+                return_timestamps: 'word',
+                chunk_length_s: 30,
+            });
+            const progressMessage = {
+                status: 'transcribe-progress',
+                data: {
+                    task: 'transcription',
+                    progress: Math.round((i+1) / numChunks * 100),
+                    total: audio.length
+                }
+            };
+            self.postMessage(progressMessage);
+
+
+            // Adjust timestamps for this chunk
+            if (chunkResult.chunks) {
+                chunkResult.chunks.forEach(chunk => {
+                    if (chunk.timestamp) {
+                        chunk.timestamp[0] += start / 16000; // Convert samples to seconds
+                        chunk.timestamp[1] += start / 16000;
+                    }
+                });
+            }
+
+            transcriptResults.push(chunkResult);
+        }
+
+        // Combine results
+        const transcript = {
+            text: transcriptResults.map(r => r.text).join(''),
+            chunks: transcriptResults.flatMap(r => r.chunks || [])
+        };
+
+        // Run segmentation in parallel with the last chunk
+        const segments = await segment(segmentation_processor, segmentation_model, audio);
+
+        // Ensure transcription shows as complete
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'transcription',
+                progress: 100,
+                total: audio.length
+            }
+        });
+        const end = performance.now();
+
+        const completeMessage = {
+            status: 'complete',
+            result: { transcript, segments },
+            audio_length: audioLengthSeconds,
+            time: end - start
+        };
+        self.postMessage(completeMessage);
+    } catch (error) {
+        console.error('Processing error:', error);
+        const errorMessage = {
+            status: 'error',
+            error: error.message || 'Failed to process audio'
+        };
+        console.log('Worker sending error:', errorMessage);
+        self.postMessage(errorMessage);
+    }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+    const { type, data } = e.data;
+
+    switch (type) {
+        case 'load':
+            load(data);
+            break;
+
+        case 'run':
+            run(data);
+            break;
+    }
+});
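The chunked transcription loop in `run()` above offsets each chunk's word timestamps by the chunk's starting sample and reports percentage progress per chunk. A worked example of that arithmetic, assuming a 4-minute 16 kHz mono clip (the clip length is illustrative):

```js
// Worked example of the chunking arithmetic in run() (illustrative numbers).
const SAMPLE_RATE = 16000;
const audioSamples = 240 * SAMPLE_RATE;                 // 240 s clip -> 3,840,000 samples
const CHUNK_SIZE = 3 * 30 * SAMPLE_RATE;                // 1,440,000 samples, i.e. 90 s per chunk
const numChunks = Math.ceil(audioSamples / CHUNK_SIZE); // ceil(2.67) = 3 chunks

// For chunk i = 2 (the last one):
const start = 2 * CHUNK_SIZE;                           // 2,880,000 samples
const offsetSeconds = start / SAMPLE_RATE;              // 180 s added to each word timestamp
const progress = Math.round((2 + 1) / numChunks * 100); // 100 (%) reported to the UI
```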