DocWolle commited on
Commit
f813885
·
verified ·
1 Parent(s): c6f3ee7

The EUROPEAN_UNION model supports transcription in a predefined language. Supported: all EU languages and Norwegian

Browse files
Generate_tflite_for_whisper_base_EUROPEAN_UNION_version.ipynb ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "c5g9NTF_Ixad"
7
+ },
8
+ "source": [
9
+ "## Install Transformers and datasets"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "w4VPaSlnHUvT"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "!pip install transformers==4.33.0\n",
21
+ "!pip install tensorflow==2.14.0"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "ClniiYCWHK4b"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "! pip install datasets"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "metadata": {
38
+ "id": "pljpioLsJOtb"
39
+ },
40
+ "source": [
41
+ "## Load pre-trained TF Whisper Base model"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {
48
+ "id": "BJNOxn5vHaGi"
49
+ },
50
+ "outputs": [],
51
+ "source": [
52
+ "import tensorflow as tf\n",
53
+ "from transformers import TFWhisperModel, WhisperFeatureExtractor\n",
54
+ "from datasets import load_dataset\n",
55
+ "\n",
56
+ "model = TFWhisperModel.from_pretrained(\"openai/whisper-base\")\n",
57
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-base\")\n",
58
+ "\n",
59
+ "ds = load_dataset(\"google/fleurs\", \"fr_fr\", split=\"test\")\n",
60
+ "inputs = feature_extractor(\n",
61
+ " ds[0][\"audio\"][\"array\"], sampling_rate=ds[0][\"audio\"][\"sampling_rate\"], return_tensors=\"tf\"\n",
62
+ ")\n",
63
+ "input_features = inputs.input_features\n",
64
+ "print(input_features)\n",
65
+ "decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id\n",
66
+ "last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state\n",
67
+ "list(last_hidden_state.shape)"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "metadata": {
73
+ "id": "W9XP25uhJl44"
74
+ },
75
+ "source": [
76
+ "## Generate saved model"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {
83
+ "id": "vpYwMmgyHf0B"
84
+ },
85
+ "outputs": [],
86
+ "source": [
87
+ "model.save('/content/tf_whisper_saved')"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {
93
+ "id": "TY_79jFEJYyJ"
94
+ },
95
+ "source": [
96
+ "## Convert saved model to TFLite model"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "id": "owez2zvzHl-p"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "import tensorflow as tf\n",
108
+ "\n",
109
+ "saved_model_dir = '/content/tf_whisper_saved'\n",
110
+ "tflite_model_path = 'whisper.tflite'\n",
111
+ "\n",
112
+ "# Convert the model\n",
113
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
114
+ "converter.target_spec.supported_ops = [\n",
115
+ " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n",
116
+ " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n",
117
+ "]\n",
118
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
119
+ "tflite_model = converter.convert()\n",
120
+ "\n",
121
+ "# Save the model\n",
122
+ "with open(tflite_model_path, 'wb') as f:\n",
123
+ " f.write(tflite_model)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "metadata": {
130
+ "id": "tFkzUrjIbNcH",
131
+ "colab": {
132
+ "base_uri": "https://localhost:8080/"
133
+ },
134
+ "outputId": "0611db92-81d4-4473-9d21-ccc19da5d5c5"
135
+ },
136
+ "outputs": [
137
+ {
138
+ "output_type": "stream",
139
+ "name": "stdout",
140
+ "text": [
141
+ "total 73812\n",
142
+ "drwxr-xr-x 1 root root 4096 Mar 7 19:47 \u001b[0m\u001b[01;34m.\u001b[0m/\n",
143
+ "drwxr-xr-x 1 root root 4096 Mar 7 19:39 \u001b[01;34m..\u001b[0m/\n",
144
+ "drwxr-xr-x 4 root root 4096 Mar 6 14:29 \u001b[01;34m.config\u001b[0m/\n",
145
+ "drwxr-xr-x 1 root root 4096 Mar 6 14:29 \u001b[01;34msample_data\u001b[0m/\n",
146
+ "drwxr-xr-x 4 root root 4096 Mar 7 21:54 \u001b[01;34mtf_whisper_saved\u001b[0m/\n",
147
+ "-rw-r--r-- 1 root root 75560432 Mar 7 21:55 whisper.tflite\n"
148
+ ]
149
+ }
150
+ ],
151
+ "source": [
152
+ "%ls -la"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "markdown",
157
+ "metadata": {
158
+ "id": "fpEnWZt7iQJK"
159
+ },
160
+ "source": [
161
+ "## Evaluate TF model"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "metadata": {
168
+ "id": "-RuFFohHg2ho",
169
+ "colab": {
170
+ "base_uri": "https://localhost:8080/"
171
+ },
172
+ "outputId": "45f8972c-6e2f-4c60-cde4-090e6572d389"
173
+ },
174
+ "outputs": [
175
+ {
176
+ "output_type": "stream",
177
+ "name": "stderr",
178
+ "text": [
179
+ "/home/wolfgang/.local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
180
+ " warnings.warn(\n",
181
+ "All PyTorch model weights were used when initializing TFWhisperForConditionalGeneration.\n",
182
+ "\n",
183
+ "All the weights of TFWhisperForConditionalGeneration were initialized from the PyTorch model.\n",
184
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFWhisperForConditionalGeneration for predictions without further training.\n",
185
+ "It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
186
+ ]
187
+ },
188
+ {
189
+ "output_type": "execute_result",
190
+ "data": {
191
+ "text/plain": [
192
+ "'<|startoftranscript|><|en|><|transcribe|><|notimestamps|> The accident took place in a mountainous area, and it seemed that this was caused by a bad old man.<|endoftext|>'"
193
+ ]
194
+ },
195
+ "metadata": {},
196
+ "execution_count": 6
197
+ }
198
+ ],
199
+ "source": [
200
+ "import tensorflow as tf\n",
201
+ "from transformers import WhisperProcessor, TFWhisperForConditionalGeneration\n",
202
+ "from datasets import load_dataset\n",
203
+ "\n",
204
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-base\")\n",
205
+ "model = TFWhisperForConditionalGeneration.from_pretrained(\"openai/whisper-base\")\n",
206
+ "\n",
207
+ "ds = load_dataset(\"google/fleurs\", \"fr_fr\", split=\"test\")\n",
208
+ "\n",
209
+ "inputs = processor(ds[0][\"audio\"][\"array\"], return_tensors=\"tf\")\n",
210
+ "input_features = inputs.input_features\n",
211
+ "\n",
212
+ "generated_ids = model.generate(input_features)\n",
213
+ "\n",
214
+ "transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
215
+ "transcription"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "metadata": {
221
+ "id": "U-eKuy_cG4u0"
222
+ },
223
+ "source": [
224
+ "## Evaluate TF Lite model (naive)\n",
225
+ "\n",
226
+ "We can load the model as defined above... but the model is useless on its own. Generation is much more complex than a model forward pass."
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {
233
+ "id": "wnfHirgyG0W4"
234
+ },
235
+ "outputs": [],
236
+ "source": [
237
+ "tflite_model_path = 'whisper.tflite'\n",
238
+ "interpreter = tf.lite.Interpreter(tflite_model_path)"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "markdown",
243
+ "metadata": {
244
+ "id": "a8VJQuHJKzl4"
245
+ },
246
+ "source": [
247
+ "## Create generation-enabled TF Lite model\n",
248
+ "\n",
249
+ "The solution consists in defining a model whose serving function is the generation call. Here's an example of how to do it:"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "metadata": {
255
+ "id": "JmIgqWVgVBZN"
256
+ },
257
+ "source": [
258
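+ "First, monkey-patch `TFForceTokensLogitsProcessor` so that forcing a token uses a finite negative score instead of `-inf`; the original `-inf` values produce NaN once the model is exported to TFLite."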
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "metadata": {
265
+ "id": "e5P8s66yU7Kv"
266
+ },
267
+ "outputs": [],
268
+ "source": [
269
+ "import tensorflow as tf\n",
270
+ "import numpy as np\n",
271
+ "from transformers import TFForceTokensLogitsProcessor, TFLogitsProcessor\n",
272
+ "from typing import List, Optional, Union, Any\n",
273
+ "\n",
274
+ "# Patching methods of class TFForceTokensLogitsProcessor(TFLogitsProcessor):\n",
275
+ "\n",
276
+ "def my__init__(self, force_token_map: List[List[int]]):\n",
277
+ " force_token_map = dict(force_token_map)\n",
278
+ " # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the\n",
279
+ " # index of the array corresponds to the index of the token to be forced, for XLA compatibility.\n",
280
+ " # Indexes without forced tokens will have a negative value.\n",
281
+ " force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1\n",
282
+ " for index, token in force_token_map.items():\n",
283
+ " if token is not None:\n",
284
+ " force_token_array[index] = token\n",
285
+ " self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)\n",
286
+ "\n",
287
+ "def my__call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:\n",
288
+ " def _force_token(generation_idx):\n",
289
+ " batch_size = scores.shape[0]\n",
290
+ " current_token = self.force_token_array[generation_idx]\n",
291
+ "\n",
292
+ " # Original code below generates NaN values when the model is exported to tflite\n",
293
+ " # it just needs to be a negative number so that the forced token's value of 0 is the largest\n",
294
+ " # so it will get chosen\n",
295
+ " #new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(\"inf\")\n",
296
+ " new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(1)\n",
297
+ " indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)\n",
298
+ " updates = tf.zeros((batch_size,), dtype=scores.dtype)\n",
299
+ " new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)\n",
300
+ " return new_scores\n",
301
+ "\n",
302
+ " scores = tf.cond(\n",
303
+ " tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),\n",
304
+ " # If the current length is geq than the length of force_token_array, the processor does nothing.\n",
305
+ " lambda: tf.identity(scores),\n",
306
+ " # Otherwise, it may force a certain token.\n",
307
+ " lambda: tf.cond(\n",
308
+ " tf.greater_equal(self.force_token_array[cur_len], 0),\n",
309
+ " # Only valid (non-negative) tokens are forced\n",
310
+ " lambda: _force_token(cur_len),\n",
311
+ " # Otherwise, the processor does nothing.\n",
312
+ " lambda: scores,\n",
313
+ " ),\n",
314
+ " )\n",
315
+ " return scores\n",
316
+ "\n",
317
+ "TFForceTokensLogitsProcessor.__init__ = my__init__\n",
318
+ "TFForceTokensLogitsProcessor.__call__ = my__call__"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "metadata": {
325
+ "id": "rIkUCdiyU7ZT"
326
+ },
327
+ "outputs": [],
328
+ "source": [
329
+ "import tensorflow as tf\n",
330
+ "\n",
331
+ "class GenerateModel(tf.Module):\n",
332
+ " def __init__(self, model):\n",
333
+ " super(GenerateModel, self).__init__()\n",
334
+ " self.model = model\n",
335
+ "\n",
336
+ " @tf.function(\n",
337
+ " input_signature=[\n",
338
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
339
+ " tf.TensorSpec((), tf.int32, name=\"lang_token\"),\n",
340
+ " ],\n",
341
+ " )\n",
342
+ " def transcribe_lang(self, input_features, lang_token):\n",
343
+ " if lang_token == 50259:\n",
344
+ " outputs = self.model.generate(\n",
345
+ " input_features,\n",
346
+ " max_new_tokens=450,\n",
347
+ " return_dict_in_generate=True,\n",
348
+ " forced_decoder_ids=[[1, 50259], [2, 50359], [3, 50363]],\n",
349
+ " )\n",
350
+ "\n",
351
+ " elif lang_token == 50261:\n",
352
+ " outputs = self.model.generate(\n",
353
+ " input_features,\n",
354
+ " max_new_tokens=450,\n",
355
+ " return_dict_in_generate=True,\n",
356
+ " forced_decoder_ids=[[1, 50261], [2, 50359], [3, 50363]],\n",
357
+ " )\n",
358
+ "\n",
359
+ " elif lang_token == 50262:\n",
360
+ " outputs = self.model.generate(\n",
361
+ " input_features,\n",
362
+ " max_new_tokens=450,\n",
363
+ " return_dict_in_generate=True,\n",
364
+ " forced_decoder_ids=[[1, 50262], [2, 50359], [3, 50363]],\n",
365
+ " )\n",
366
+ "\n",
367
+ " elif lang_token == 50265:\n",
368
+ " outputs = self.model.generate(\n",
369
+ " input_features,\n",
370
+ " max_new_tokens=450,\n",
371
+ " return_dict_in_generate=True,\n",
372
+ " forced_decoder_ids=[[1, 50265], [2, 50359], [3, 50363]],\n",
373
+ " )\n",
374
+ "\n",
375
+ " elif lang_token == 50267:\n",
376
+ " outputs = self.model.generate(\n",
377
+ " input_features,\n",
378
+ " max_new_tokens=450,\n",
379
+ " return_dict_in_generate=True,\n",
380
+ " forced_decoder_ids=[[1, 50267], [2, 50359], [3, 50363]],\n",
381
+ " )\n",
382
+ "\n",
383
+ " elif lang_token == 50268:\n",
384
+ " outputs = self.model.generate(\n",
385
+ " input_features,\n",
386
+ " max_new_tokens=450,\n",
387
+ " return_dict_in_generate=True,\n",
388
+ " forced_decoder_ids=[[1, 50268], [2, 50359], [3, 50363]],\n",
389
+ " )\n",
390
+ "\n",
391
+ " elif lang_token == 50269:\n",
392
+ " outputs = self.model.generate(\n",
393
+ " input_features,\n",
394
+ " max_new_tokens=450,\n",
395
+ " return_dict_in_generate=True,\n",
396
+ " forced_decoder_ids=[[1, 50269], [2, 50359], [3, 50363]],\n",
397
+ " )\n",
398
+ "\n",
399
+ " elif lang_token == 50271:\n",
400
+ " outputs = self.model.generate(\n",
401
+ " input_features,\n",
402
+ " max_new_tokens=450,\n",
403
+ " return_dict_in_generate=True,\n",
404
+ " forced_decoder_ids=[[1, 50271], [2, 50359], [3, 50363]],\n",
405
+ " )\n",
406
+ "\n",
407
+ " elif lang_token == 50273:\n",
408
+ " outputs = self.model.generate(\n",
409
+ " input_features,\n",
410
+ " max_new_tokens=450,\n",
411
+ " return_dict_in_generate=True,\n",
412
+ " forced_decoder_ids=[[1, 50273], [2, 50359], [3, 50363]],\n",
413
+ " )\n",
414
+ "\n",
415
+ " elif lang_token == 50274:\n",
416
+ " outputs = self.model.generate(\n",
417
+ " input_features,\n",
418
+ " max_new_tokens=450,\n",
419
+ " return_dict_in_generate=True,\n",
420
+ " forced_decoder_ids=[[1, 50274], [2, 50359], [3, 50363]],\n",
421
+ " )\n",
422
+ "\n",
423
+ " elif lang_token == 50277:\n",
424
+ " outputs = self.model.generate(\n",
425
+ " input_features,\n",
426
+ " max_new_tokens=450,\n",
427
+ " return_dict_in_generate=True,\n",
428
+ " forced_decoder_ids=[[1, 50277], [2, 50359], [3, 50363]],\n",
429
+ " )\n",
430
+ "\n",
431
+ " elif lang_token == 50281:\n",
432
+ " outputs = self.model.generate(\n",
433
+ " input_features,\n",
434
+ " max_new_tokens=450,\n",
435
+ " return_dict_in_generate=True,\n",
436
+ " forced_decoder_ids=[[1, 50281], [2, 50359], [3, 50363]],\n",
437
+ " )\n",
438
+ "\n",
439
+ " elif lang_token == 50283:\n",
440
+ " outputs = self.model.generate(\n",
441
+ " input_features,\n",
442
+ " max_new_tokens=450,\n",
443
+ " return_dict_in_generate=True,\n",
444
+ " forced_decoder_ids=[[1, 50283], [2, 50359], [3, 50363]],\n",
445
+ " )\n",
446
+ "\n",
447
+ " elif lang_token == 50284:\n",
448
+ " outputs = self.model.generate(\n",
449
+ " input_features,\n",
450
+ " max_new_tokens=450,\n",
451
+ " return_dict_in_generate=True,\n",
452
+ " forced_decoder_ids=[[1, 50284], [2, 50359], [3, 50363]],\n",
453
+ " )\n",
454
+ "\n",
455
+ " elif lang_token == 50285:\n",
456
+ " outputs = self.model.generate(\n",
457
+ " input_features,\n",
458
+ " max_new_tokens=450,\n",
459
+ " return_dict_in_generate=True,\n",
460
+ " forced_decoder_ids=[[1, 50285], [2, 50359], [3, 50363]],\n",
461
+ " )\n",
462
+ "\n",
463
+ " elif lang_token == 50286:\n",
464
+ " outputs = self.model.generate(\n",
465
+ " input_features,\n",
466
+ " max_new_tokens=450,\n",
467
+ " return_dict_in_generate=True,\n",
468
+ " forced_decoder_ids=[[1, 50286], [2, 50359], [3, 50363]],\n",
469
+ " )\n",
470
+ "\n",
471
+ " elif lang_token == 50288:\n",
472
+ " outputs = self.model.generate(\n",
473
+ " input_features,\n",
474
+ " max_new_tokens=450,\n",
475
+ " return_dict_in_generate=True,\n",
476
+ " forced_decoder_ids=[[1, 50288], [2, 50359], [3, 50363]],\n",
477
+ " )\n",
478
+ "\n",
479
+ " elif lang_token == 50291:\n",
480
+ " outputs = self.model.generate(\n",
481
+ " input_features,\n",
482
+ " max_new_tokens=450,\n",
483
+ " return_dict_in_generate=True,\n",
484
+ " forced_decoder_ids=[[1, 50291], [2, 50359], [3, 50363]],\n",
485
+ " )\n",
486
+ "\n",
487
+ " elif lang_token == 50292:\n",
488
+ " outputs = self.model.generate(\n",
489
+ " input_features,\n",
490
+ " max_new_tokens=450,\n",
491
+ " return_dict_in_generate=True,\n",
492
+ " forced_decoder_ids=[[1, 50292], [2, 50359], [3, 50363]],\n",
493
+ " )\n",
494
+ "\n",
495
+ " elif lang_token == 50293:\n",
496
+ " outputs = self.model.generate(\n",
497
+ " input_features,\n",
498
+ " max_new_tokens=450,\n",
499
+ " return_dict_in_generate=True,\n",
500
+ " forced_decoder_ids=[[1, 50293], [2, 50359], [3, 50363]],\n",
501
+ " )\n",
502
+ "\n",
503
+ " elif lang_token == 50298:\n",
504
+ " outputs = self.model.generate(\n",
505
+ " input_features,\n",
506
+ " max_new_tokens=450,\n",
507
+ " return_dict_in_generate=True,\n",
508
+ " forced_decoder_ids=[[1, 50298], [2, 50359], [3, 50363]],\n",
509
+ " )\n",
510
+ "\n",
511
+ " elif lang_token == 50301:\n",
512
+ " outputs = self.model.generate(\n",
513
+ " input_features,\n",
514
+ " max_new_tokens=450,\n",
515
+ " return_dict_in_generate=True,\n",
516
+ " forced_decoder_ids=[[1, 50301], [2, 50359], [3, 50363]],\n",
517
+ " )\n",
518
+ "\n",
519
+ " elif lang_token == 50305:\n",
520
+ " outputs = self.model.generate(\n",
521
+ " input_features,\n",
522
+ " max_new_tokens=450,\n",
523
+ " return_dict_in_generate=True,\n",
524
+ " forced_decoder_ids=[[1, 50305], [2, 50359], [3, 50363]],\n",
525
+ " )\n",
526
+ "\n",
527
+ " elif lang_token == 50307:\n",
528
+ " outputs = self.model.generate(\n",
529
+ " input_features,\n",
530
+ " max_new_tokens=450,\n",
531
+ " return_dict_in_generate=True,\n",
532
+ " forced_decoder_ids=[[1, 50307], [2, 50359], [3, 50363]],\n",
533
+ " )\n",
534
+ "\n",
535
+ " elif lang_token == 50343:\n",
536
+ " outputs = self.model.generate(\n",
537
+ " input_features,\n",
538
+ " max_new_tokens=450,\n",
539
+ " return_dict_in_generate=True,\n",
540
+ " forced_decoder_ids=[[1, 50343], [2, 50359], [3, 50363]],\n",
541
+ " )\n",
542
+ "\n",
543
+ " elif lang_token == 50345:\n",
544
+ " outputs = self.model.generate(\n",
545
+ " input_features,\n",
546
+ " max_new_tokens=450,\n",
547
+ " return_dict_in_generate=True,\n",
548
+ " forced_decoder_ids=[[1, 50345], [2, 50359], [3, 50363]],\n",
549
+ " )\n",
550
+ "\n",
551
+ " else:\n",
552
+ " outputs = self.model.generate(\n",
553
+ " input_features,\n",
554
+ " max_new_tokens=450, # change as needed\n",
555
+ " return_dict_in_generate=True,\n",
556
+ " forced_decoder_ids=[[2, 50359], [3, 50363]],\n",
557
+ " )\n",
558
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
559
+ "\n",
560
+ "\n",
561
+ " @tf.function(\n",
562
+ " input_signature=[\n",
563
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
564
+ " ],\n",
565
+ " )\n",
566
+ " def transcribe(self, input_features):\n",
567
+ " outputs = self.model.generate(\n",
568
+ " input_features,\n",
569
+ " max_new_tokens=450, # change as needed\n",
570
+ " return_dict_in_generate=True,\n",
571
+ " forced_decoder_ids=[[2, 50359], [3, 50363]],\n",
572
+ " )\n",
573
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
574
+ "\n",
575
+ " @tf.function(\n",
576
+ " input_signature=[\n",
577
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
578
+ " ],\n",
579
+ " )\n",
580
+ " def translate(self, input_features):\n",
581
+ " outputs = self.model.generate(\n",
582
+ " input_features,\n",
583
+ " max_new_tokens=450, # change as needed\n",
584
+ " return_dict_in_generate=True,\n",
585
+ " forced_decoder_ids=[[2, 50358], [3, 50363]],\n",
586
+ " )\n",
587
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
588
+ "\n",
589
+ "# Assuming `model` is already defined and loaded\n",
590
+ "saved_model_dir = '/content/tf_whisper_saved'\n",
591
+ "tflite_model_path = 'whisper.tflite'\n",
592
+ "\n",
593
+ "generate_model = GenerateModel(model=model)\n",
594
+ "tf.saved_model.save(generate_model, saved_model_dir, signatures={\n",
595
+ " \"serving_default\": generate_model.transcribe,\n",
596
+ " \"serving_transcribe\": generate_model.transcribe,\n",
597
+ " \"serving_translate\": generate_model.translate,\n",
598
+ " \"serving_transcribe_lang\": generate_model.transcribe_lang,\n",
599
+ "\n",
600
+ "})\n",
601
+ "\n",
602
+ "# Convert the model\n",
603
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
604
+ "converter.target_spec.supported_ops = [\n",
605
+ " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n",
606
+ " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n",
607
+ "]\n",
608
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
609
+ "tflite_model = converter.convert()\n",
610
+ "\n",
611
+ "# Save the model\n",
612
+ "with open(tflite_model_path, 'wb') as f:\n",
613
+ " f.write(tflite_model)"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "source": [
619
+ "pwd"
620
+ ],
621
+ "metadata": {
622
+ "colab": {
623
+ "base_uri": "https://localhost:8080/",
624
+ "height": 35
625
+ },
626
+ "id": "llf-5421rZ-G",
627
+ "outputId": "869a4c96-7f76-4834-d00f-9f4078f4d300"
628
+ },
629
+ "execution_count": null,
630
+ "outputs": [
631
+ {
632
+ "output_type": "execute_result",
633
+ "data": {
634
+ "text/plain": [
635
+ "'/content'"
636
+ ],
637
+ "application/vnd.google.colaboratory.intrinsic+json": {
638
+ "type": "string"
639
+ }
640
+ },
641
+ "metadata": {},
642
+ "execution_count": 2
643
+ }
644
+ ]
645
+ },
646
+ {
647
+ "cell_type": "code",
648
+ "source": [
649
+ "!zip -r /content/tf_whisper_saved.zip /content/tf_whisper_saved/"
650
+ ],
651
+ "metadata": {
652
+ "colab": {
653
+ "base_uri": "https://localhost:8080/"
654
+ },
655
+ "collapsed": true,
656
+ "id": "7pnAWtGZp6MJ",
657
+ "outputId": "42d6c775-1af9-4482-837a-eb3537d5e2c0"
658
+ },
659
+ "execution_count": null,
660
+ "outputs": [
661
+ {
662
+ "output_type": "stream",
663
+ "name": "stdout",
664
+ "text": [
665
+ " adding: content/tf_whisper_saved/ (stored 0%)\n",
666
+ " adding: content/tf_whisper_saved/assets/ (stored 0%)\n",
667
+ " adding: content/tf_whisper_saved/variables/ (stored 0%)\n",
668
+ " adding: content/tf_whisper_saved/variables/variables.data-00000-of-00001 (deflated 41%)\n",
669
+ " adding: content/tf_whisper_saved/variables/variables.index (deflated 79%)\n",
670
+ " adding: content/tf_whisper_saved/fingerprint.pb (stored 0%)\n",
671
+ " adding: content/tf_whisper_saved/keras_metadata.pb (deflated 96%)\n",
672
+ " adding: content/tf_whisper_saved/saved_model.pb (deflated 93%)\n"
673
+ ]
674
+ }
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "execution_count": null,
680
+ "metadata": {
681
+ "id": "u9MustgMU7oI"
682
+ },
683
+ "outputs": [],
684
+ "source": [
685
+ "# loaded model... now with generate!\n",
686
+ "tflite_model_path = 'whisper.tflite'\n",
687
+ "interpreter = tf.lite.Interpreter(tflite_model_path)\n",
688
+ "\n",
689
+ "tflite_generate = interpreter.get_signature_runner('serving_transcribe_lang')\n",
690
+ "lang_token = tf.constant([50286], dtype=tf.int32)\n",
691
+ "generated_ids = tflite_generate(input_features=input_features,lang_token=lang_token)[\"sequences\"]\n",
692
+ "transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
693
+ "transcription\n",
694
+ "\n",
695
+ "\n"
696
+ ]
697
+ }
698
+ ],
699
+ "metadata": {
700
+ "colab": {
701
+ "machine_shape": "hm",
702
+ "provenance": []
703
+ },
704
+ "kernelspec": {
705
+ "display_name": "Python 3",
706
+ "name": "python3"
707
+ },
708
+ "language_info": {
709
+ "name": "python"
710
+ }
711
+ },
712
+ "nbformat": 4,
713
+ "nbformat_minor": 0
714
+ }
whisper-base.EUROPEAN_UNION.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8fb947274f87da1298f24b7dce39920cebb9f9558f6deb0c6a591a3bbb395bb
3
+ size 94819264
whisper-base.EUROPEAN_UNION.tokens ADDED
@@ -0,0 +1 @@
 
 
1
+ [50259, 50261, 50262, 50265, 50267, 50268, 50269, 50271, 50273, 50274, 50277, 50281, 50283, 50284, 50285, 50286, 50288, 50291, 50292, 50293, 50298, 50301, 50305, 50307, 50343, 50345]