pr/fixes_plus_loading

#6
by nbpe97 - opened
whisper-speaker-diarization/package.json CHANGED
@@ -10,11 +10,13 @@
     "preview": "vite preview"
   },
   "dependencies": {
-    "@xenova/transformers": "github:xenova/transformers.js#v3",
+    "@huggingface/transformers": "^3.3.1",
+    "prop-types": "^15.8.1",
     "react": "^18.3.1",
     "react-dom": "^18.3.1"
   },
   "devDependencies": {
+    "@rollup/plugin-commonjs": "^28.0.1",
     "@types/react": "^18.3.3",
     "@types/react-dom": "^18.3.0",
     "@vitejs/plugin-react": "^4.3.1",
whisper-speaker-diarization/src/App.jsx CHANGED
@@ -1,218 +1,257 @@
-import { useEffect, useState, useRef, useCallback } from 'react';
-
-import Progress from './components/Progress';
-import MediaInput from './components/MediaInput';
-import Transcript from './components/Transcript';
-import LanguageSelector from './components/LanguageSelector';
-
-
-async function hasWebGPU() {
-  if (!navigator.gpu) {
-    return false;
-  }
-  try {
-    const adapter = await navigator.gpu.requestAdapter();
-    return !!adapter;
-  } catch (e) {
-    return false;
-  }
-}
-
-function App() {
-
-  // Create a reference to the worker object.
-  const worker = useRef(null);
-
-  // Model loading and progress
-  const [status, setStatus] = useState(null);
-  const [loadingMessage, setLoadingMessage] = useState('');
-  const [progressItems, setProgressItems] = useState([]);
-
-  const mediaInputRef = useRef(null);
-  const [audio, setAudio] = useState(null);
-  const [language, setLanguage] = useState('en');
-
-  const [result, setResult] = useState(null);
-  const [time, setTime] = useState(null);
-  const [currentTime, setCurrentTime] = useState(0);
-
-  const [device, setDevice] = useState('webgpu'); // Try use WebGPU first
-  const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
-  useEffect(() => {
-    hasWebGPU().then((b) => {
-      setModelSize(b ? 196 : 77);
-      setDevice(b ? 'webgpu' : 'wasm');
-    });
-  }, []);
-
-  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
-  useEffect(() => {
-    if (!worker.current) {
-      // Create the worker if it does not yet exist.
-      worker.current = new Worker(new URL('./worker.js', import.meta.url), {
-        type: 'module'
-      });
-    }
-
-    // Create a callback function for messages from the worker thread.
-    const onMessageReceived = (e) => {
-      switch (e.data.status) {
-        case 'loading':
-          // Model file start load: add a new progress item to the list.
-          setStatus('loading');
-          setLoadingMessage(e.data.data);
-          break;
-
-        case 'initiate':
-          setProgressItems(prev => [...prev, e.data]);
-          break;
-
-        case 'progress':
-          // Model file progress: update one of the progress items.
-          setProgressItems(
-            prev => prev.map(item => {
-              if (item.file === e.data.file) {
-                return { ...item, ...e.data }
-              }
-              return item;
-            })
-          );
-          break;
-
-        case 'done':
-          // Model file loaded: remove the progress item from the list.
-          setProgressItems(
-            prev => prev.filter(item => item.file !== e.data.file)
-          );
-          break;
-
-        case 'loaded':
-          // Pipeline ready: the worker is ready to accept messages.
-          setStatus('ready');
-          break;
-
-        case 'complete':
-          setResult(e.data.result);
-          setTime(e.data.time);
-          setStatus('ready');
-          break;
-      }
-    };
-
-    // Attach the callback function as an event listener.
-    worker.current.addEventListener('message', onMessageReceived);
-
-    // Define a cleanup function for when the component is unmounted.
-    return () => {
-      worker.current.removeEventListener('message', onMessageReceived);
-    };
-  }, []);
-
-  const handleClick = useCallback(() => {
-    setResult(null);
-    setTime(null);
-    if (status === null) {
-      setStatus('loading');
-      worker.current.postMessage({ type: 'load', data: { device } });
-    } else {
-      setStatus('running');
-      worker.current.postMessage({
-        type: 'run', data: { audio, language }
-      });
-    }
-  }, [status, audio, language, device]);
-
-  return (
-    <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
-
-      {status === 'loading' && (
-        <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
-          <div className="w-[500px]">
-            <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
-            {progressItems.map(({ file, progress, total }, i) => (
-              <Progress key={i} text={file} percentage={progress} total={total} />
-            ))}
-          </div>
-        </div>
-      )}
-      <div className="my-auto">
-        <div className="flex flex-col items-center mb-2 text-center">
-          <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
-          <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
-        </div>
-
-        <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
-          {
-            !audio && (
-              <p className="mb-2">
-                You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
-                two powerful speech recognition models for generating word-level timestamps across 100 different languages and speaker segmentation, respectively.
-                Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
-                <br />
-                Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗&nbsp;Transformers.js</a> and ONNX Runtime Web,
-                meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
-              </p>
-            )
-          }
-
-          <div className="flex flex-col w-full m-3 max-w-[520px]">
-            <span className="text-sm mb-0.5">Input audio/video</span>
-            <MediaInput
-              ref={mediaInputRef}
-              className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
-              onInputChange={(audio) => {
-                setResult(null);
-                setAudio(audio);
-              }}
-              onTimeUpdate={(time) => setCurrentTime(time)}
-            />
-          </div>
-
-          <div className="relative w-full flex justify-center items-center">
-            <button
-              className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
-              onClick={handleClick}
-              disabled={status === 'running' || (status !== null && audio === null)}
-            >
-              {status === null ? 'Load model' :
-                status === 'running'
-                  ? 'Running...'
-                  : 'Run model'
-              }
-            </button>
-
-            {status !== null &&
-              <div className='absolute right-0 bottom-0'>
-                <span className="text-xs">Language:</span>
-                <br />
-                <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
-              </div>
-            }
-          </div>
-
-          {
-            result && time && (
-              <>
-                <div className="w-full mt-4 border rounded-md">
-                  <Transcript
-                    className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
-                    transcript={result.transcript}
-                    segments={result.segments}
-                    currentTime={currentTime}
-                    setCurrentTime={(time) => {
-                      setCurrentTime(time);
-                      mediaInputRef.current.setMediaTime(time);
-                    }}
-                  />
-                </div>
-                <p className="text-sm text-gray-600 text-end p-1">Generation time: <span className="text-gray-800 font-semibold">{time.toFixed(2)}ms</span></p>
-              </>
-            )
-          }
-        </div>
-      </div>
-    </div >
-  )
-}
-
-export default App
+import { useEffect, useState, useRef, useCallback } from 'react';
+
+import Progress from './components/Progress';
+import MediaInput from './components/MediaInput';
+import Transcript from './components/Transcript';
+import LanguageSelector from './components/LanguageSelector';
+
+
+async function hasWebGPU() {
+  if (!navigator.gpu) {
+    return false;
+  }
+  try {
+    const adapter = await navigator.gpu.requestAdapter();
+    return !!adapter;
+  } catch (e) {
+    return false;
+  }
+}
+
+function App() {
+
+  // Create a reference to the worker object.
+  const worker = useRef(null);
+
+  // Model loading and progress
+  const [status, setStatus] = useState(null);
+  const [loadingMessage, setLoadingMessage] = useState('');
+  const [progressItems, setProgressItems] = useState([]);
+
+  const mediaInputRef = useRef(null);
+  const [audio, setAudio] = useState(null);
+  const [language, setLanguage] = useState('en');
+
+  const [result, setResult] = useState(null);
+  const [time, setTime] = useState(null);
+  const [audioLength, setAudioLength] = useState(null);
+  const [currentTime, setCurrentTime] = useState(0);
+
+  const [device, setDevice] = useState('webgpu'); // Try use WebGPU first
+  const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
+  useEffect(() => {
+    hasWebGPU().then((b) => {
+      setModelSize(b ? 196 : 77);
+      setDevice(b ? 'webgpu' : 'wasm');
+    });
+  }, []);
+
+  // Create a callback function for messages from the worker thread.
+  const onMessageReceived = (e) => {
+    switch (e.data.status) {
+      case 'loading':
+        // Model file start load: add a new progress item to the list.
+        setStatus('loading');
+        setLoadingMessage(e.data.data);
+        break;
+
+      case 'initiate':
+        setProgressItems(prev => [...prev, e.data]);
+        break;
+
+      case 'progress':
+        // Model file progress: update one of the progress items.
+        setProgressItems(
+          prev => prev.map(item => {
+            if (item.file === e.data.file) {
+              return { ...item, ...e.data }
+            }
+            return item;
+          })
+        );
+        break;
+
+      case 'done':
+        // Model file loaded: remove the progress item from the list.
+        setProgressItems(
+          prev => prev.filter(item => item.file !== e.data.file)
+        );
+        break;
+
+      case 'loaded':
+        // Pipeline ready: the worker is ready to accept messages.
+        setStatus('ready');
+        break;
+
+      case 'transcribe-progress': {
+        // Update progress for transcription/diarization
+        const { task, progress, total } = e.data.data;
+        setProgressItems(prev => {
+          const existingIndex = prev.findIndex(item => item.file === task);
+          if (existingIndex >= 0) {
+            return prev.map((item, i) =>
+              i === existingIndex ? { ...item, progress, total } : item
+            );
+          }
+          const newItem = { file: task, progress, total };
+          return [...prev, newItem];
+        });
+        break;
+      }
+
+      case 'complete':
+        setResult(e.data.result);
+        setTime(e.data.time);
+        setAudioLength(e.data.audio_length);
+        setStatus('ready');
+        break;
+    }
+  };
+
+  // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
+  useEffect(() => {
+    if (!worker.current) {
+      // Create the worker if it does not yet exist.
+      worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+        type: 'module'
+      });
+    }
+
+    // Attach the callback function as an event listener.
+    worker.current.addEventListener('message', onMessageReceived);
+
+    // Define a cleanup function for when the component is unmounted.
+    return () => {
+      worker.current.removeEventListener('message', onMessageReceived);
+    };
+  }, []);
+
+  const handleClick = useCallback(() => {
+    setResult(null);
+    setTime(null);
+    if (status === null) {
+      setStatus('loading');
+      worker.current.postMessage({ type: 'load', data: { device } });
+    } else {
+      setStatus('running');
+      worker.current.postMessage({
+        type: 'run', data: { audio, language }
+      });
+    }
+  }, [status, audio, language, device]);
+
+  return (
+    <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
+
+      {(status === 'loading' || status === 'running') && (
+        <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
+          <div className="w-[500px]">
+            <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
+            {progressItems
+              .sort((a, b) => {
+                // Define the order: transcription -> segmentation -> diarization
+                const order = { 'transcription': 0, 'segmentation': 1, 'diarization': 2 };
+                return (order[a.file] ?? 3) - (order[b.file] ?? 3);
+              })
+              .map(({ file, progress, total }, i) => (
+                <Progress
+                  key={i}
+                  text={file === 'transcription' ? 'Converting speech to text' :
+                    file === 'segmentation' ? 'Detecting word timestamps' :
+                      file === 'diarization' ? 'Identifying speakers' :
+                        file}
+                  percentage={progress}
+                  total={total}
+                />
+              ))
+            }
+          </div>
+        </div>
+      )}
+      <div className="my-auto">
+        <div className="flex flex-col items-center mb-2 text-center">
+          <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
+          <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
+        </div>
+
+        <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
+          {
+            !audio && (
+              <p className="mb-2">
+                You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
+                two powerful speech recognition models for generating word-level timestamps across 100 different languages and speaker segmentation, respectively.
+                Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
+                <br />
+                Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗&nbsp;Transformers.js</a> and ONNX Runtime Web,
+                meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
+              </p>
+            )
+          }
+
+          <div className="flex flex-col w-full m-3 max-w-[520px]">
+            <span className="text-sm mb-0.5">Input audio/video</span>
+            <MediaInput
+              ref={mediaInputRef}
+              className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
+              onInputChange={(audio) => {
+                setResult(null);
+                setAudio(audio);
+              }}
+              onTimeUpdate={(time) => setCurrentTime(time)}
+              onMessage={onMessageReceived}
+            />
+          </div>
+
+          <div className="relative w-full flex justify-center items-center">
+            <button
+              className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
+              onClick={handleClick}
+              disabled={status === 'running' || (status !== null && audio === null)}
+            >
+              {status === null ? 'Load model' :
+                status === 'running'
+                  ? 'Running...'
+                  : 'Run model'
+              }
+            </button>
+
+            {status !== null &&
+              <div className='absolute right-0 bottom-0'>
+                <span className="text-xs">Language:</span>
+                <br />
+                <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
+              </div>
+            }
+          </div>
+
+          {
+            result && time && (
+              <>
+                <div className="w-full mt-4 border rounded-md">
+                  <Transcript
+                    className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
+                    transcript={result.transcript}
+                    segments={result.segments}
+                    currentTime={currentTime}
+                    setCurrentTime={(time) => {
+                      setCurrentTime(time);
+                      mediaInputRef.current.setMediaTime(time);
+                    }}
+                  />
+                </div>
+                <p className="text-sm text-end p-1">Generation time:
+                  <span className="font-semibold">{(time / 1000).toLocaleString()} s</span>
+                </p>
+                <p className="text-sm text-end p-1">
+                  <span className="font-semibold">{(audioLength / (time / 1000)).toFixed(2)}x transcription!</span>
+                </p>
+              </>
+            )
+          }
+        </div>
+      </div>
+    </div >
+  )
+}
+
+export default App
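
For reference, the new 'transcribe-progress' branch above expects messages of the following shape, whether they come from the worker or from MediaInput via onMessage; the concrete values below are illustrative only:

// Illustrative message consumed by the 'transcribe-progress' case in onMessageReceived.
const exampleProgressMessage = {
  status: 'transcribe-progress',
  data: {
    task: 'transcription',   // one of 'transcription', 'segmentation', 'diarization'
    progress: 42,            // percent complete for this task
    total: 480000,           // forwarded to <Progress> as its `total` prop
  },
};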
whisper-speaker-diarization/src/components/MediaInput.jsx CHANGED
@@ -1,8 +1,8 @@
1
  import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
2
-
3
  const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
4
 
5
- const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) => {
6
  // UI states
7
  const [dragging, setDragging] = useState(false);
8
  const fileInputRef = useRef(null);
@@ -89,7 +89,40 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
89
  const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
90
 
91
  try {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  const audioBuffer = await audioContext.decodeAudioData(buffer);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  let audio;
94
  if (audioBuffer.numberOfChannels === 2) {
95
  // Merge channels
@@ -145,8 +178,8 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
145
  onClick={handleClick}
146
  onDragOver={handleDragOver}
147
  onDrop={handleDrop}
148
- onDragEnter={(e) => setDragging(true)}
149
- onDragLeave={(e) => setDragging(false)}
150
  >
151
  <input
152
  type="file"
@@ -189,6 +222,13 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
189
  </div>
190
  );
191
  });
 
 
 
 
 
 
 
192
  MediaInput.displayName = 'MediaInput';
193
 
194
  export default MediaInput;
 
1
  import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
2
+ import PropTypes from 'prop-types';
3
  const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
4
 
5
+ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, onMessage, ...props }, ref) => {
6
  // UI states
7
  const [dragging, setDragging] = useState(false);
8
  const fileInputRef = useRef(null);
 
89
  const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
90
 
91
  try {
92
+ // Start audio decoding
93
+ onMessage({
94
+ data: {
95
+ status: 'loading',
96
+ data: 'Decoding audio buffer...'
97
+ }
98
+ });
99
+
100
+ onMessage({
101
+ data: {
102
+ status: 'initiate',
103
+ name: 'audio-decoder',
104
+ file: 'audio-buffer'
105
+ }
106
+ });
107
+
108
  const audioBuffer = await audioContext.decodeAudioData(buffer);
109
+
110
+ // Audio decoding complete
111
+ onMessage({
112
+ data: {
113
+ status: 'done',
114
+ name: 'audio-decoder',
115
+ file: 'audio-buffer'
116
+ }
117
+ });
118
+
119
+ // Audio decoding complete
120
+ onMessage({
121
+ data: {
122
+ status: 'loaded'
123
+ }
124
+ });
125
+
126
  let audio;
127
  if (audioBuffer.numberOfChannels === 2) {
128
  // Merge channels
 
178
  onClick={handleClick}
179
  onDragOver={handleDragOver}
180
  onDrop={handleDrop}
181
+ onDragEnter={() => setDragging(true)}
182
+ onDragLeave={() => setDragging(false)}
183
  >
184
  <input
185
  type="file"
 
222
  </div>
223
  );
224
  });
225
+
226
+ MediaInput.propTypes = {
227
+ onInputChange: PropTypes.func.isRequired,
228
+ onTimeUpdate: PropTypes.func.isRequired,
229
+ onMessage: PropTypes.func.isRequired
230
+ };
231
+
232
  MediaInput.displayName = 'MediaInput';
233
 
234
  export default MediaInput;
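
With onMessage now a required prop, a minimal usage of MediaInput looks like the sketch below (the handlers are illustrative no-ops; App.jsx passes its onMessageReceived callback instead):

// Minimal MediaInput usage satisfying the new propTypes; all handlers are placeholders.
<MediaInput
  onInputChange={(audio) => console.log('decoded samples:', audio?.length)}
  onTimeUpdate={(time) => console.log('playback position (s):', time)}
  onMessage={(e) => console.log('decode status:', e.data.status)}
/>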
whisper-speaker-diarization/src/components/Progress.jsx CHANGED
@@ -1,3 +1,5 @@
+import PropTypes from 'prop-types';
+
 function formatBytes(size) {
     const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
     return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
@@ -13,3 +15,9 @@ export default function Progress({ text, percentage, total }) {
         </div>
     );
 }
+
+Progress.propTypes = {
+    text: PropTypes.string.isRequired,
+    percentage: PropTypes.number,
+    total: PropTypes.number
+};
whisper-speaker-diarization/src/worker.js CHANGED
@@ -1,124 +1,272 @@
-
-import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@xenova/transformers';
-
-const PER_DEVICE_CONFIG = {
-  webgpu: {
-    dtype: {
-      encoder_model: 'fp32',
-      decoder_model_merged: 'q4',
-    },
-    device: 'webgpu',
-  },
-  wasm: {
-    dtype: 'q8',
-    device: 'wasm',
-  },
-};
-
-/**
- * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
- */
-class PipelineSingeton {
-  static asr_model_id = 'onnx-community/whisper-base_timestamped';
-  static asr_instance = null;
-
-  static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
-  static segmentation_instance = null;
-  static segmentation_processor = null;
-
-  static async getInstance(progress_callback = null, device = 'webgpu') {
-    this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
-      ...PER_DEVICE_CONFIG[device],
-      progress_callback,
-    });
-
-    this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
-      progress_callback,
-    });
-    this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
-      // NOTE: WebGPU is not currently supported for this model
-      // See https://github.com/microsoft/onnxruntime/issues/21386
-      device: 'wasm',
-      dtype: 'fp32',
-      progress_callback,
-    });
-
-    return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
-  }
-}
-
-async function load({ device }) {
-  self.postMessage({
-    status: 'loading',
-    data: `Loading models (${device})...`
-  });
-
-  // Load the pipeline and save it for future use.
-  const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
-    // We also add a progress callback to the pipeline so that we can
-    // track model loading.
-    self.postMessage(x);
-  }, device);
-
-  if (device === 'webgpu') {
-    self.postMessage({
-      status: 'loading',
-      data: 'Compiling shaders and warming up model...'
-    });
-
-    await transcriber(new Float32Array(16_000), {
-      language: 'en',
-    });
-  }
-
-  self.postMessage({ status: 'loaded' });
-}
-
-async function segment(processor, model, audio) {
-  const inputs = await processor(audio);
-  const { logits } = await model(inputs);
-  const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
-
-  // Attach labels
-  for (const segment of segments) {
-    segment.label = model.config.id2label[segment.id];
-  }
-
-  return segments;
-}
-
-async function run({ audio, language }) {
-  const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
-
-  const start = performance.now();
-
-  // Run transcription and segmentation in parallel
-  const [transcript, segments] = await Promise.all([
-    transcriber(audio, {
-      language,
-      return_timestamps: 'word',
-      chunk_length_s: 30,
-    }),
-    segment(segmentation_processor, segmentation_model, audio)
-  ]);
-  console.table(segments, ['start', 'end', 'id', 'label', 'confidence']);
-
-  const end = performance.now();
-
-  self.postMessage({ status: 'complete', result: { transcript, segments }, time: end - start });
-}
-
-// Listen for messages from the main thread
-self.addEventListener('message', async (e) => {
-  const { type, data } = e.data;
-
-  switch (type) {
-    case 'load':
-      load(data);
-      break;
-
-    case 'run':
-      run(data);
-      break;
-  }
-});
+
+import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@huggingface/transformers';
+
+const PER_DEVICE_CONFIG = {
+  webgpu: {
+    dtype: {
+      encoder_model: 'fp32',
+      decoder_model_merged: 'q4',
+    },
+    device: 'webgpu',
+  },
+  wasm: {
+    dtype: 'q8',
+    device: 'wasm',
+  },
+};
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class PipelineSingeton {
+  static asr_model_id = 'onnx-community/whisper-base_timestamped';
+  static asr_instance = null;
+
+  static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
+  static segmentation_instance = null;
+  static segmentation_processor = null;
+
+  static async getInstance(progress_callback = null, device = 'webgpu') {
+    this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
+      ...PER_DEVICE_CONFIG[device],
+      progress_callback,
+    });
+
+    this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
+      progress_callback,
+    });
+    this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
+      // NOTE: WebGPU is not currently supported for this model
+      // See https://github.com/microsoft/onnxruntime/issues/21386
+      device: 'wasm',
+      dtype: 'fp32',
+      progress_callback,
+    });
+
+    return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
+  }
+}
+
+async function load({ device }) {
+  try {
+    const message = {
+      status: 'loading',
+      data: `Loading models (${device})...`
+    };
+    self.postMessage(message);
+
+    const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
+      // We also add a progress callback to the pipeline so that we can
+      // track model loading.
+      self.postMessage(x);
+    }, device);
+
+    if (device === 'webgpu') {
+      const warmupMessage = {
+        status: 'loading',
+        data: 'Compiling shaders and warming up model...'
+      };
+
+      self.postMessage(warmupMessage);
+
+      await transcriber(new Float32Array(16_000), {
+        language: 'en',
+      });
+    }
+
+    self.postMessage({ status: 'loaded' });
+  } catch (error) {
+    console.error('Loading error:', error);
+    const errorMessage = {
+      status: 'error',
+      error: error.message || 'Failed to load models'
+    };
+    self.postMessage(errorMessage);
+  }
+}
+
+async function segment(processor, model, audio) {
+  try {
+    // Report start of segmentation
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'segmentation',
+        progress: 0,
+        total: audio.length
+      }
+    });
+
+    // Process audio in chunks to show progress
+    const inputs = await processor(audio);
+
+    // Report segmentation feature extraction progress
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'segmentation',
+        progress: 50,
+        total: audio.length
+      }
+    });
+
+    const { logits } = await model(inputs);
+
+    // Report segmentation completion
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'segmentation',
+        progress: 100,
+        total: audio.length
+      }
+    });
+
+    // Start diarization
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'diarization',
+        progress: 0,
+        total: audio.length
+      }
+    });
+
+    const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
+
+    // Attach labels and report diarization completion
+    for (const segment of segments) {
+      segment.label = model.config.id2label[segment.id];
+    }
+
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'diarization',
+        progress: 100,
+        total: audio.length
+      }
+    });
+
+    return segments;
+  } catch (error) {
+    console.error('Segmentation error:', error);
+    return [{
+      id: 0,
+      start: 0,
+      end: (audio.length / 480016) * 30,
+      label: 'SPEAKER_00',
+      confidence: 1.0
+    }];
+  }
+}
+
+async function run({ audio, language }) {
+  try {
+    const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
+
+    const audioLengthSeconds = (audio.length / 16000);
+
+    // Initialize transcription progress
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'transcription',
+        progress: 0,
+        total: audio.length
+      }
+    });
+
+    const start = performance.now();
+    // Process in 30-second chunks
+    const CHUNK_SIZE = 3 * 30 * 16000; // 30 seconds * 16000 samples/second
+    const numChunks = Math.ceil(audio.length / CHUNK_SIZE);
+    let transcriptResults = [];
+
+    for (let i = 0; i < numChunks; i++) {
+      const start = i * CHUNK_SIZE;
+      const end = Math.min((i + 1) * CHUNK_SIZE, audio.length);
+      const chunk = audio.slice(start, end);
+
+      // Process chunk
+      const chunkResult = await transcriber(chunk, {
+        language,
+        return_timestamps: 'word',
+        chunk_length_s: 30,
+      });
+      const progressMessage = {
+        status: 'transcribe-progress',
+        data: {
+          task: 'transcription',
+          progress: Math.round((i+1) / numChunks * 100),
+          total: audio.length
+        }
+      };
+      self.postMessage(progressMessage);
+
+
+      // Adjust timestamps for this chunk
+      if (chunkResult.chunks) {
+        chunkResult.chunks.forEach(chunk => {
+          if (chunk.timestamp) {
+            chunk.timestamp[0] += start / 16000; // Convert samples to seconds
+            chunk.timestamp[1] += start / 16000;
+          }
+        });
+      }
+
+      transcriptResults.push(chunkResult);
+    }
+
+    // Combine results
+    const transcript = {
+      text: transcriptResults.map(r => r.text).join(''),
+      chunks: transcriptResults.flatMap(r => r.chunks || [])
+    };
+
+    // Run segmentation in parallel with the last chunk
+    const segments = await segment(segmentation_processor, segmentation_model, audio);
+
+    // Ensure transcription shows as complete
+    self.postMessage({
+      status: 'transcribe-progress',
+      data: {
+        task: 'transcription',
+        progress: 100,
+        total: audio.length
+      }
+    });
+    const end = performance.now();
+
+    const completeMessage = {
+      status: 'complete',
+      result: { transcript, segments },
+      audio_length: audioLengthSeconds,
+      time: end - start
+    };
+    self.postMessage(completeMessage);
+  } catch (error) {
+    console.error('Processing error:', error);
+    const errorMessage = {
+      status: 'error',
+      error: error.message || 'Failed to process audio'
+    };
+    console.log('Worker sending error:', errorMessage);
+    self.postMessage(errorMessage);
+  }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+  const { type, data } = e.data;
+
+  switch (type) {
+    case 'load':
+      load(data);
+      break;
+
+    case 'run':
+      run(data);
+      break;
+  }
+});
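
The reworked worker can now post a message with status: 'error' (from both load and run), which the onMessageReceived switch in App.jsx above does not yet handle. One possible way the main thread could surface it, sketched with the existing setStatus/setLoadingMessage setters (this wiring is not part of the PR):

// Hypothetical handler for the new 'error' status; names other than setStatus/setLoadingMessage are illustrative.
function handleWorkerError(e, { setStatus, setLoadingMessage }) {
  if (e.data.status !== 'error') return;
  setStatus(null);                               // re-enable the "Load model" button
  setLoadingMessage(`Error: ${e.data.error}`);   // or route the message to a dedicated error state
}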