import { pipeline } from "@huggingface/transformers";

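// Device-specific pipeline options. On WebGPU, the encoder runs in fp32 while
// the merged decoder uses 4-bit quantization; the WASM fallback quantizes the
// whole model to 8 bits.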
const PER_DEVICE_CONFIG = {
  webgpu: {
    dtype: {
      encoder_model: "fp32",
      decoder_model_merged: "q4",
    },
    device: "webgpu",
  },
  wasm: {
    dtype: "q8",
    device: "wasm",
  },
};

/**
 * This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
 */
class PipelineSingleton {
  static model_id = "onnx-community/whisper-base_timestamped";
  static instance = null;

  static async getInstance(progress_callback = null, device = "webgpu") {
    if (!this.instance) {
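      // Store the pipeline *promise* rather than the resolved pipeline, so
      // concurrent callers share a single in-flight model load.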
      this.instance = pipeline("automatic-speech-recognition", this.model_id, {
        ...PER_DEVICE_CONFIG[device],
        progress_callback,
      });
    }
    return this.instance;
  }
}

async function load({ device }) {
  self.postMessage({
    status: "loading",
    data: `Loading model (${device})...`,
  });

  // Load the pipeline and save it for future use. We also pass a progress
  // callback so the main thread can track model loading.
  const transcriber = await PipelineSingleton.getInstance((x) => {
    self.postMessage(x);
  }, device);

  if (device === "webgpu") {
    self.postMessage({
      status: "loading",
      data: "Compiling shaders and warming up model...",
    });

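    // Run the model once on one second of silence (16,000 samples at 16 kHz)
    // so shader compilation happens before the first real request.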
    await transcriber(new Float32Array(16_000), {
      language: "en",
    });
  }

  self.postMessage({ status: "ready" });
}

async function run({ audio, language }) {
  const transcriber = await PipelineSingleton.getInstance();

  // Transcribe the audio and measure how long inference takes.
  const start = performance.now();

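  // Request word-level timestamps and process long audio in 30-second chunks.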
  const result = await transcriber(audio, {
    language,
    return_timestamps: "word",
    chunk_length_s: 30,
  });

  const end = performance.now();

  self.postMessage({ status: "complete", result, time: end - start });
}

// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "load":
      load(data);
      break;

    case "run":
      run(data);
      break;
  }
});
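
// Example main-thread usage (a sketch; assumes this file is bundled as a
// module worker at "./worker.js" and that `audio` is a 16 kHz mono
// Float32Array, which is what Whisper expects):
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), {
//     type: "module",
//   });
//   worker.addEventListener("message", (e) => {
//     if (e.data.status === "complete") console.log(e.data.result);
//   });
//   worker.postMessage({ type: "load", data: { device: "webgpu" } });
//   // ...once the "ready" message arrives:
//   worker.postMessage({ type: "run", data: { audio, language: "en" } });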