pr/fixes_plus_loading #6
opened by nbpe97
whisper-speaker-diarization/package.json
CHANGED
@@ -10,11 +10,13 @@
     "preview": "vite preview"
   },
   "dependencies": {
-    "@
+    "@huggingface/transformers": "^3.3.1",
+    "prop-types": "^15.8.1",
     "react": "^18.3.1",
     "react-dom": "^18.3.1"
   },
   "devDependencies": {
+    "@rollup/plugin-commonjs": "^28.0.1",
     "@types/react": "^18.3.3",
     "@types/react-dom": "^18.3.0",
     "@vitejs/plugin-react": "^4.3.1",
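The PR adds @rollup/plugin-commonjs as a devDependency, but no vite.config.js change appears in this diff. A minimal sketch of how such a Rollup plugin is typically wired into a Vite + React config; the file below is an assumption for illustration only, not part of this PR:

// vite.config.js — hypothetical wiring of the new devDependency
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
import commonjs from '@rollup/plugin-commonjs';

export default defineConfig({
    // Vite accepts Rollup-compatible plugins; commonjs() helps bundle
    // CommonJS-only packages (e.g. prop-types) when the default handling is not enough.
    plugins: [react(), commonjs()],
});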
whisper-speaker-diarization/src/App.jsx
CHANGED
@@ -1,218 +1,257 @@
+import { useEffect, useState, useRef, useCallback } from 'react';
+
+import Progress from './components/Progress';
+import MediaInput from './components/MediaInput';
+import Transcript from './components/Transcript';
+import LanguageSelector from './components/LanguageSelector';
+
+
+async function hasWebGPU() {
+    if (!navigator.gpu) {
+        return false;
+    }
+    try {
+        const adapter = await navigator.gpu.requestAdapter();
+        return !!adapter;
+    } catch (e) {
+        return false;
+    }
+}
+
+function App() {
+
+    // Create a reference to the worker object.
+    const worker = useRef(null);
+
+    // Model loading and progress
+    const [status, setStatus] = useState(null);
+    const [loadingMessage, setLoadingMessage] = useState('');
+    const [progressItems, setProgressItems] = useState([]);
+
+    const mediaInputRef = useRef(null);
+    const [audio, setAudio] = useState(null);
+    const [language, setLanguage] = useState('en');
+
+    const [result, setResult] = useState(null);
+    const [time, setTime] = useState(null);
+    const [audioLength, setAudioLength] = useState(null);
+    const [currentTime, setCurrentTime] = useState(0);
+
+    const [device, setDevice] = useState('webgpu'); // Try to use WebGPU first
+    const [modelSize, setModelSize] = useState('gpu' in navigator ? 196 : 77); // WebGPU=196MB, WebAssembly=77MB
+    useEffect(() => {
+        hasWebGPU().then((b) => {
+            setModelSize(b ? 196 : 77);
+            setDevice(b ? 'webgpu' : 'wasm');
+        });
+    }, []);
+
+    // Create a callback function for messages from the worker thread.
+    const onMessageReceived = (e) => {
+        switch (e.data.status) {
+            case 'loading':
+                // Models have started loading: show the overlay with a status message.
+                setStatus('loading');
+                setLoadingMessage(e.data.data);
+                break;
+
+            case 'initiate':
+                // A model file has started downloading: add a new progress item to the list.
+                setProgressItems(prev => [...prev, e.data]);
+                break;
+
+            case 'progress':
+                // Model file progress: update one of the progress items.
+                setProgressItems(
+                    prev => prev.map(item => {
+                        if (item.file === e.data.file) {
+                            return { ...item, ...e.data }
+                        }
+                        return item;
+                    })
+                );
+                break;
+
+            case 'done':
+                // Model file loaded: remove the progress item from the list.
+                setProgressItems(
+                    prev => prev.filter(item => item.file !== e.data.file)
+                );
+                break;
+
+            case 'loaded':
+                // Pipeline ready: the worker is ready to accept messages.
+                setStatus('ready');
+                break;
+
+            case 'transcribe-progress': {
+                // Update progress for transcription/diarization
+                const { task, progress, total } = e.data.data;
+                setProgressItems(prev => {
+                    const existingIndex = prev.findIndex(item => item.file === task);
+                    if (existingIndex >= 0) {
+                        return prev.map((item, i) =>
+                            i === existingIndex ? { ...item, progress, total } : item
+                        );
+                    }
+                    const newItem = { file: task, progress, total };
+                    return [...prev, newItem];
+                });
+                break;
+            }
+
+            case 'complete':
+                setResult(e.data.result);
+                setTime(e.data.time);
+                setAudioLength(e.data.audio_length);
+                setStatus('ready');
+                break;
+        }
+    };
+
+    // We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
+    useEffect(() => {
+        if (!worker.current) {
+            // Create the worker if it does not yet exist.
+            worker.current = new Worker(new URL('./worker.js', import.meta.url), {
+                type: 'module'
+            });
+        }
+
+        // Attach the callback function as an event listener.
+        worker.current.addEventListener('message', onMessageReceived);
+
+        // Define a cleanup function for when the component is unmounted.
+        return () => {
+            worker.current.removeEventListener('message', onMessageReceived);
+        };
+    }, []);
+
+    const handleClick = useCallback(() => {
+        setResult(null);
+        setTime(null);
+        if (status === null) {
+            setStatus('loading');
+            worker.current.postMessage({ type: 'load', data: { device } });
+        } else {
+            setStatus('running');
+            worker.current.postMessage({
+                type: 'run', data: { audio, language }
+            });
+        }
+    }, [status, audio, language, device]);
+
+    return (
+        <div className="flex flex-col h-screen mx-auto text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900 max-w-[600px]">
+
+            {(status === 'loading' || status === 'running') && (
+                <div className="flex justify-center items-center fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] top-0 left-0">
+                    <div className="w-[500px]">
+                        <p className="text-center mb-1 text-white text-md">{loadingMessage}</p>
+                        {progressItems
+                            .sort((a, b) => {
+                                // Define the order: transcription -> segmentation -> diarization
+                                const order = { 'transcription': 0, 'segmentation': 1, 'diarization': 2 };
+                                return (order[a.file] ?? 3) - (order[b.file] ?? 3);
+                            })
+                            .map(({ file, progress, total }, i) => (
+                                <Progress
+                                    key={i}
+                                    text={file === 'transcription' ? 'Converting speech to text' :
+                                        file === 'segmentation' ? 'Detecting word timestamps' :
+                                            file === 'diarization' ? 'Identifying speakers' :
+                                                file}
+                                    percentage={progress}
+                                    total={total}
+                                />
+                            ))
+                        }
+                    </div>
+                </div>
+            )}
+            <div className="my-auto">
+                <div className="flex flex-col items-center mb-2 text-center">
+                    <h1 className="text-5xl font-bold mb-2">Whisper Diarization</h1>
+                    <h2 className="text-xl font-semibold">In-browser automatic speech recognition w/ <br />word-level timestamps and speaker segmentation</h2>
+                </div>
+
+                <div className="w-full min-h-[220px] flex flex-col justify-center items-center">
+                    {
+                        !audio && (
+                            <p className="mb-2">
+                                You are about to download <a href="https://huggingface.co/onnx-community/whisper-base_timestamped" target="_blank" rel="noreferrer" className="font-medium underline">whisper-base</a> and <a href="https://huggingface.co/onnx-community/pyannote-segmentation-3.0" target="_blank" rel="noreferrer" className="font-medium underline">pyannote-segmentation-3.0</a>,
+                                two powerful speech recognition models for generating word-level timestamps across 100 different languages and speaker segmentation, respectively.
+                                Once loaded, the models ({modelSize}MB + 6MB) will be cached and reused when you revisit the page.<br />
+                                <br />
+                                Everything runs locally in your browser using <a href="https://huggingface.co/docs/transformers.js" target="_blank" rel="noreferrer" className="underline">🤗 Transformers.js</a> and ONNX Runtime Web,
+                                meaning no API calls are made to a server for inference. You can even disconnect from the internet after the model has loaded!
+                            </p>
+                        )
+                    }
+
+                    <div className="flex flex-col w-full m-3 max-w-[520px]">
+                        <span className="text-sm mb-0.5">Input audio/video</span>
+                        <MediaInput
+                            ref={mediaInputRef}
+                            className="flex items-center border rounded-md cursor-pointer min-h-[100px] max-h-[500px] overflow-hidden"
+                            onInputChange={(audio) => {
+                                setResult(null);
+                                setAudio(audio);
+                            }}
+                            onTimeUpdate={(time) => setCurrentTime(time)}
+                            onMessage={onMessageReceived}
+                        />
+                    </div>
+
+                    <div className="relative w-full flex justify-center items-center">
+                        <button
+                            className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
+                            onClick={handleClick}
+                            disabled={status === 'running' || (status !== null && audio === null)}
+                        >
+                            {status === null ? 'Load model' :
+                                status === 'running'
+                                    ? 'Running...'
+                                    : 'Run model'
+                            }
+                        </button>
+
+                        {status !== null &&
+                            <div className='absolute right-0 bottom-0'>
+                                <span className="text-xs">Language:</span>
+                                <br />
+                                <LanguageSelector className="border rounded-lg p-1 max-w-[100px]" language={language} setLanguage={setLanguage} />
+                            </div>
+                        }
+                    </div>
+
+                    {
+                        result && time && (
+                            <>
+                                <div className="w-full mt-4 border rounded-md">
+                                    <Transcript
+                                        className="p-2 max-h-[200px] overflow-y-auto scrollbar-thin select-none"
+                                        transcript={result.transcript}
+                                        segments={result.segments}
+                                        currentTime={currentTime}
+                                        setCurrentTime={(time) => {
+                                            setCurrentTime(time);
+                                            mediaInputRef.current.setMediaTime(time);
+                                        }}
+                                    />
+                                </div>
+                                <p className="text-sm text-end p-1">Generation time:
+                                    <span className="font-semibold">{(time / 1000).toLocaleString()} s</span>
+                                </p>
+                                <p className="text-sm text-end p-1">
+                                    <span className="font-semibold">{(audioLength / (time / 1000)).toFixed(2)}x transcription!</span>
+                                </p>
+                            </>
+                        )
+                    }
+                </div>
+            </div>
+        </div>
+    )
+}
+
+export default App
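Note that the reworked worker.js (below) also posts messages with status 'error' on load or processing failures, which the switch in onMessageReceived above does not handle. A minimal sketch of such a case, assuming a hypothetical errorMessage state hook that is not part of this PR:

            // Hypothetical addition to the switch in onMessageReceived — not part of this PR.
            case 'error':
                // Assumes: const [errorMessage, setErrorMessage] = useState(null);
                setErrorMessage(e.data.error); // error string posted by worker.js on failure
                setStatus(null);               // return to the initial state so the user can retry
                break;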
whisper-speaker-diarization/src/components/MediaInput.jsx
CHANGED
@@ -1,8 +1,8 @@
 import { useState, forwardRef, useRef, useImperativeHandle, useEffect, useCallback } from 'react';
-
+import PropTypes from 'prop-types';
 const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/hopper.webm';
 
-const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) => {
+const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, onMessage, ...props }, ref) => {
     // UI states
     const [dragging, setDragging] = useState(false);
     const fileInputRef = useRef(null);
@@ -89,7 +89,40 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
         const audioContext = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16_000 });
 
         try {
+            // Start audio decoding: show the loading overlay and add a progress item
+            onMessage({
+                data: {
+                    status: 'loading',
+                    data: 'Decoding audio buffer...'
+                }
+            });
+
+            onMessage({
+                data: {
+                    status: 'initiate',
+                    name: 'audio-decoder',
+                    file: 'audio-buffer'
+                }
+            });
+
             const audioBuffer = await audioContext.decodeAudioData(buffer);
+
+            // Audio decoding complete: remove the progress item
+            onMessage({
+                data: {
+                    status: 'done',
+                    name: 'audio-decoder',
+                    file: 'audio-buffer'
+                }
+            });
+
+            // Dismiss the loading overlay (App sets its status back to 'ready')
+            onMessage({
+                data: {
+                    status: 'loaded'
+                }
+            });
+
             let audio;
             if (audioBuffer.numberOfChannels === 2) {
                 // Merge channels
@@ -145,8 +178,8 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
             onClick={handleClick}
             onDragOver={handleDragOver}
             onDrop={handleDrop}
-            onDragEnter={(
-            onDragLeave={(
+            onDragEnter={() => setDragging(true)}
+            onDragLeave={() => setDragging(false)}
         >
             <input
                 type="file"
@@ -189,6 +222,13 @@ const MediaInput = forwardRef(({ onInputChange, onTimeUpdate, ...props }, ref) =
         </div>
     );
 });
+
+MediaInput.propTypes = {
+    onInputChange: PropTypes.func.isRequired,
+    onTimeUpdate: PropTypes.func.isRequired,
+    onMessage: PropTypes.func.isRequired
+};
+
 MediaInput.displayName = 'MediaInput';
 
 export default MediaInput;
whisper-speaker-diarization/src/components/Progress.jsx
CHANGED
@@ -1,3 +1,5 @@
+import PropTypes from 'prop-types';
+
 function formatBytes(size) {
     const i = size == 0 ? 0 : Math.floor(Math.log(size) / Math.log(1024));
     return +((size / Math.pow(1024, i)).toFixed(2)) * 1 + ['B', 'kB', 'MB', 'GB', 'TB'][i];
@@ -13,3 +15,9 @@ export default function Progress({ text, percentage, total }) {
         </div>
     );
 }
+
+Progress.propTypes = {
+    text: PropTypes.string.isRequired,
+    percentage: PropTypes.number,
+    total: PropTypes.number
+};
whisper-speaker-diarization/src/worker.js
CHANGED
@@ -1,124 +1,272 @@
+
+import { pipeline, AutoProcessor, AutoModelForAudioFrameClassification } from '@huggingface/transformers';
+
+const PER_DEVICE_CONFIG = {
+    webgpu: {
+        dtype: {
+            encoder_model: 'fp32',
+            decoder_model_merged: 'q4',
+        },
+        device: 'webgpu',
+    },
+    wasm: {
+        dtype: 'q8',
+        device: 'wasm',
+    },
+};
+
+/**
+ * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
+ */
+class PipelineSingeton {
+    static asr_model_id = 'onnx-community/whisper-base_timestamped';
+    static asr_instance = null;
+
+    static segmentation_model_id = 'onnx-community/pyannote-segmentation-3.0';
+    static segmentation_instance = null;
+    static segmentation_processor = null;
+
+    static async getInstance(progress_callback = null, device = 'webgpu') {
+        this.asr_instance ??= pipeline('automatic-speech-recognition', this.asr_model_id, {
+            ...PER_DEVICE_CONFIG[device],
+            progress_callback,
+        });
+
+        this.segmentation_processor ??= AutoProcessor.from_pretrained(this.segmentation_model_id, {
+            progress_callback,
+        });
+        this.segmentation_instance ??= AutoModelForAudioFrameClassification.from_pretrained(this.segmentation_model_id, {
+            // NOTE: WebGPU is not currently supported for this model
+            // See https://github.com/microsoft/onnxruntime/issues/21386
+            device: 'wasm',
+            dtype: 'fp32',
+            progress_callback,
+        });
+
+        return Promise.all([this.asr_instance, this.segmentation_processor, this.segmentation_instance]);
+    }
+}
+
+async function load({ device }) {
+    try {
+        const message = {
+            status: 'loading',
+            data: `Loading models (${device})...`
+        };
+        self.postMessage(message);
+
+        const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance(x => {
+            // We also add a progress callback to the pipeline so that we can
+            // track model loading.
+            self.postMessage(x);
+        }, device);
+
+        if (device === 'webgpu') {
+            const warmupMessage = {
+                status: 'loading',
+                data: 'Compiling shaders and warming up model...'
+            };
+
+            self.postMessage(warmupMessage);
+
+            await transcriber(new Float32Array(16_000), {
+                language: 'en',
+            });
+        }
+
+        self.postMessage({ status: 'loaded' });
+    } catch (error) {
+        console.error('Loading error:', error);
+        const errorMessage = {
+            status: 'error',
+            error: error.message || 'Failed to load models'
+        };
+        self.postMessage(errorMessage);
+    }
+}
+
+async function segment(processor, model, audio) {
+    try {
+        // Report start of segmentation
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'segmentation',
+                progress: 0,
+                total: audio.length
+            }
+        });
+
+        // Extract input features for the segmentation model
+        const inputs = await processor(audio);
+
+        // Report segmentation feature extraction progress
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'segmentation',
+                progress: 50,
+                total: audio.length
+            }
+        });
+
+        const { logits } = await model(inputs);
+
+        // Report segmentation completion
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'segmentation',
+                progress: 100,
+                total: audio.length
+            }
+        });
+
+        // Start diarization
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'diarization',
+                progress: 0,
+                total: audio.length
+            }
+        });
+
+        const segments = processor.post_process_speaker_diarization(logits, audio.length)[0];
+
+        // Attach labels and report diarization completion
+        for (const segment of segments) {
+            segment.label = model.config.id2label[segment.id];
+        }
+
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'diarization',
+                progress: 100,
+                total: audio.length
+            }
+        });
+
+        return segments;
+    } catch (error) {
+        console.error('Segmentation error:', error);
+        return [{
+            id: 0,
+            start: 0,
+            end: audio.length / 16_000, // fall back to a single segment spanning the full audio (16 kHz samples)
+            label: 'SPEAKER_00',
+            confidence: 1.0
+        }];
+    }
+}
+
+async function run({ audio, language }) {
+    try {
+        const [transcriber, segmentation_processor, segmentation_model] = await PipelineSingeton.getInstance();
+
+        const audioLengthSeconds = (audio.length / 16000);
+
+        // Initialize transcription progress
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'transcription',
+                progress: 0,
+                total: audio.length
+            }
+        });
+
+        const start = performance.now();
+        // Process the audio in 90-second chunks so progress can be reported between chunks
+        const CHUNK_SIZE = 3 * 30 * 16000; // 3 × 30 seconds × 16,000 samples/second
+        const numChunks = Math.ceil(audio.length / CHUNK_SIZE);
+        let transcriptResults = [];
+
+        for (let i = 0; i < numChunks; i++) {
+            const start = i * CHUNK_SIZE;
+            const end = Math.min((i + 1) * CHUNK_SIZE, audio.length);
+            const chunk = audio.slice(start, end);
+
+            // Process chunk
+            const chunkResult = await transcriber(chunk, {
+                language,
+                return_timestamps: 'word',
+                chunk_length_s: 30,
+            });
+            const progressMessage = {
+                status: 'transcribe-progress',
+                data: {
+                    task: 'transcription',
+                    progress: Math.round((i + 1) / numChunks * 100),
+                    total: audio.length
+                }
+            };
+            self.postMessage(progressMessage);
+
+            // Adjust timestamps for this chunk
+            if (chunkResult.chunks) {
+                chunkResult.chunks.forEach(chunk => {
+                    if (chunk.timestamp) {
+                        chunk.timestamp[0] += start / 16000; // Convert samples to seconds
+                        chunk.timestamp[1] += start / 16000;
+                    }
+                });
+            }
+
+            transcriptResults.push(chunkResult);
+        }
+
+        // Combine results
+        const transcript = {
+            text: transcriptResults.map(r => r.text).join(''),
+            chunks: transcriptResults.flatMap(r => r.chunks || [])
+        };
+
+        // Run speaker segmentation over the full audio
+        const segments = await segment(segmentation_processor, segmentation_model, audio);
+
+        // Ensure transcription shows as complete
+        self.postMessage({
+            status: 'transcribe-progress',
+            data: {
+                task: 'transcription',
+                progress: 100,
+                total: audio.length
+            }
+        });
+        const end = performance.now();
+
+        const completeMessage = {
+            status: 'complete',
+            result: { transcript, segments },
+            audio_length: audioLengthSeconds,
+            time: end - start
+        };
+        self.postMessage(completeMessage);
+    } catch (error) {
+        console.error('Processing error:', error);
+        const errorMessage = {
+            status: 'error',
+            error: error.message || 'Failed to process audio'
+        };
+        console.log('Worker sending error:', errorMessage);
+        self.postMessage(errorMessage);
+    }
+}
+
+// Listen for messages from the main thread
+self.addEventListener('message', async (e) => {
+    const { type, data } = e.data;
+
+    switch (type) {
+        case 'load':
+            load(data);
+            break;
+
+        case 'run':
+            run(data);
+            break;
+    }
+});
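For reference, a condensed sketch of the main-thread side of the protocol implemented above. It is illustrative only; the real handlers live in App.jsx and MediaInput.jsx:

// Illustration of the worker protocol; see App.jsx for the actual implementation.
const worker = new Worker(new URL('./worker.js', import.meta.url), { type: 'module' });

worker.addEventListener('message', (e) => {
    // Statuses posted back by worker.js: 'loading', 'initiate', 'progress', 'done',
    // 'loaded', 'transcribe-progress', 'complete', 'error'.
    if (e.data.status === 'complete') {
        console.log(e.data.result.transcript, e.data.result.segments, e.data.time);
    }
});

worker.postMessage({ type: 'load', data: { device: 'webgpu' } }); // or 'wasm'

// `audio` is a Float32Array of 16 kHz mono samples (as produced by MediaInput.jsx);
// a one-second buffer of silence is used here purely as a placeholder.
const audio = new Float32Array(16_000);
worker.postMessage({ type: 'run', data: { audio, language: 'en' } });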