Duplicate from tomiwa1a/video-search
Browse filesCo-authored-by: Tomiwa Ademidun <[email protected]>
- .gitattributes +32 -0
- README.md +22 -0
- create_handler.ipynb +289 -0
- handler.py +260 -0
- requirements.txt +3 -0
- sample1.flac +0 -0
.gitattributes
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
24 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: gpl-3.0
|
3 |
+
tags:
|
4 |
+
- audio
|
5 |
+
- automatic-speech-recognition
|
6 |
+
- endpoints-template
|
7 |
+
library_name: generic
|
8 |
+
inference: false
|
9 |
+
duplicated_from: tomiwa1a/video-search
|
10 |
+
---
|
11 |
+
|
12 |
+
# Video Search
|
13 |
+
|
14 |
+
This project contains 3 different models that can be used for searching videos.
|
15 |
+
|
16 |
+
1. Whisper to convert mp3 files to audio
|
17 |
+
2. BART Sentence Transformer to generate vector embeddings from text
|
18 |
+
3. BART LFQA to generate long form answers given a context
|
19 |
+
|
20 |
+
For more context, see: [Atlas: Find Anything on Youtube](https://atila.ca/blog/tomiwa/atlas)
|
21 |
+
|
22 |
+
Inspired by [philschmid/openai-whisper-endpoint](https://huggingface.co/philschmid/openai-whisper-endpoint)
|
create_handler.ipynb
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## 1. Setup & Installation"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"name": "stdout",
|
17 |
+
"output_type": "stream",
|
18 |
+
"text": [
|
19 |
+
"Overwriting requirements.txt\n"
|
20 |
+
]
|
21 |
+
}
|
22 |
+
],
|
23 |
+
"source": [
|
24 |
+
"%%writefile requirements.txt\n",
|
25 |
+
"git+https://github.com/openai/whisper.git@8cf36f3508c9acd341a45eb2364239a3d81458b9"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": null,
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"!pip install -r requirements.txt --upgrade"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"metadata": {},
|
40 |
+
"source": [
|
41 |
+
"## 2. Test model"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": 3,
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [
|
49 |
+
{
|
50 |
+
"name": "stdout",
|
51 |
+
"output_type": "stream",
|
52 |
+
"text": [
|
53 |
+
"--2022-09-23 20:32:18-- https://cdn-media.huggingface.co/speech_samples/sample1.flac\n",
|
54 |
+
"Resolving cdn-media.huggingface.co (cdn-media.huggingface.co)... 13.32.151.62, 13.32.151.23, 13.32.151.60, ...\n",
|
55 |
+
"Connecting to cdn-media.huggingface.co (cdn-media.huggingface.co)|13.32.151.62|:443... connected.\n",
|
56 |
+
"HTTP request sent, awaiting response... 200 OK\n",
|
57 |
+
"Length: 282378 (276K) [audio/flac]\n",
|
58 |
+
"Saving to: ‘sample1.flac’\n",
|
59 |
+
"\n",
|
60 |
+
"sample1.flac 100%[===================>] 275.76K --.-KB/s in 0.003s \n",
|
61 |
+
"\n",
|
62 |
+
"2022-09-23 20:32:18 (78.7 MB/s) - ‘sample1.flac’ saved [282378/282378]\n",
|
63 |
+
"\n"
|
64 |
+
]
|
65 |
+
}
|
66 |
+
],
|
67 |
+
"source": [
|
68 |
+
"!wget https://cdn-media.huggingface.co/speech_samples/sample1.flac"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 9,
|
74 |
+
"metadata": {},
|
75 |
+
"outputs": [
|
76 |
+
{
|
77 |
+
"name": "stderr",
|
78 |
+
"output_type": "stream",
|
79 |
+
"text": [
|
80 |
+
"100%|█████████████████████████████████████| 2.87G/2.87G [01:11<00:00, 42.9MiB/s]\n"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"name": "stdout",
|
85 |
+
"output_type": "stream",
|
86 |
+
"text": [
|
87 |
+
"Detected language: english\n",
|
88 |
+
" going along slushy country roads and speaking to damp audiences in drafty school rooms day after day for a fortnight. he'll have to put in an appearance at some place of worship on sunday morning and he can come to us immediately afterwards.\n"
|
89 |
+
]
|
90 |
+
}
|
91 |
+
],
|
92 |
+
"source": [
|
93 |
+
"import whisper\n",
|
94 |
+
"\n",
|
95 |
+
"model = whisper.load_model(\"large\")\n",
|
96 |
+
"result = model.transcribe(\"sample1.flac\")\n",
|
97 |
+
"print(result[\"text\"])"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "markdown",
|
102 |
+
"metadata": {},
|
103 |
+
"source": [
|
104 |
+
"## 3. Create Custom Handler for Inference Endpoints\n"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 5,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [
|
112 |
+
{
|
113 |
+
"name": "stdout",
|
114 |
+
"output_type": "stream",
|
115 |
+
"text": [
|
116 |
+
"Overwriting handler.py\n"
|
117 |
+
]
|
118 |
+
}
|
119 |
+
],
|
120 |
+
"source": [
|
121 |
+
"%%writefile handler.py\n",
|
122 |
+
"from typing import Dict\n",
|
123 |
+
"from transformers.pipelines.audio_utils import ffmpeg_read\n",
|
124 |
+
"import whisper\n",
|
125 |
+
"import torch\n",
|
126 |
+
"\n",
|
127 |
+
"SAMPLE_RATE = 16000\n",
|
128 |
+
"\n",
|
129 |
+
"\n",
|
130 |
+
"\n",
|
131 |
+
"class EndpointHandler():\n",
|
132 |
+
" def __init__(self, path=\"\"):\n",
|
133 |
+
" # load the model\n",
|
134 |
+
" self.model = whisper.load_model(\"medium\")\n",
|
135 |
+
"\n",
|
136 |
+
"\n",
|
137 |
+
" def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:\n",
|
138 |
+
" \"\"\"\n",
|
139 |
+
" Args:\n",
|
140 |
+
" data (:obj:):\n",
|
141 |
+
" includes the deserialized audio file as bytes\n",
|
142 |
+
" Return:\n",
|
143 |
+
" A :obj:`dict`:. base64 encoded image\n",
|
144 |
+
" \"\"\"\n",
|
145 |
+
" # process input\n",
|
146 |
+
" inputs = data.pop(\"inputs\", data)\n",
|
147 |
+
" audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)\n",
|
148 |
+
" audio_tensor= torch.from_numpy(audio_nparray)\n",
|
149 |
+
" \n",
|
150 |
+
" # run inference pipeline\n",
|
151 |
+
" result = self.model.transcribe(audio_nparray)\n",
|
152 |
+
"\n",
|
153 |
+
" # postprocess the prediction\n",
|
154 |
+
" return {\"text\": result[\"text\"]}"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "markdown",
|
159 |
+
"metadata": {},
|
160 |
+
"source": [
|
161 |
+
"test custom pipeline"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"cell_type": "code",
|
166 |
+
"execution_count": 6,
|
167 |
+
"metadata": {},
|
168 |
+
"outputs": [],
|
169 |
+
"source": [
|
170 |
+
"from handler import EndpointHandler\n",
|
171 |
+
"\n",
|
172 |
+
"# init handler\n",
|
173 |
+
"my_handler = EndpointHandler(path=\".\")"
|
174 |
+
]
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"cell_type": "code",
|
178 |
+
"execution_count": null,
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [
|
181 |
+
{
|
182 |
+
"name": "stderr",
|
183 |
+
"output_type": "stream",
|
184 |
+
"text": [
|
185 |
+
"/home/ubuntu/endpoints/openai-whisper-endpoint/handler.py:27: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
|
186 |
+
" audio_tensor= torch.from_numpy(audio_nparray)\n"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"name": "stdout",
|
191 |
+
"output_type": "stream",
|
192 |
+
"text": [
|
193 |
+
"Detected language: english\n"
|
194 |
+
]
|
195 |
+
}
|
196 |
+
],
|
197 |
+
"source": [
|
198 |
+
"import base64\n",
|
199 |
+
"from PIL import Image\n",
|
200 |
+
"from io import BytesIO\n",
|
201 |
+
"import json\n",
|
202 |
+
"\n",
|
203 |
+
"# file reader\n",
|
204 |
+
"with open(\"sample1.flac\", \"rb\") as f:\n",
|
205 |
+
" request = {\"inputs\": f.read()}\n",
|
206 |
+
"\n",
|
207 |
+
"\n",
|
208 |
+
"# test the handler\n",
|
209 |
+
"pred = my_handler(request)"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": 3,
|
215 |
+
"metadata": {},
|
216 |
+
"outputs": [
|
217 |
+
{
|
218 |
+
"data": {
|
219 |
+
"text/plain": [
|
220 |
+
"{'transcription': \" going along slushy country roads and speaking to damp audiences in draughty school rooms day after day for a fortnight. He'll have to put in an appearance at some place of worship on Sunday morning, and he can come to us immediately afterwards.\"}"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
"execution_count": 3,
|
224 |
+
"metadata": {},
|
225 |
+
"output_type": "execute_result"
|
226 |
+
}
|
227 |
+
],
|
228 |
+
"source": [
|
229 |
+
"pred"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"execution_count": 4,
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [
|
237 |
+
{
|
238 |
+
"data": {
|
239 |
+
"text/plain": [
|
240 |
+
"'{\"transcription\": \" going along slushy country roads and speaking to damp audiences in draughty school rooms day after day for a fortnight. He\\'ll have to put in an appearance at some place of worship on Sunday morning, and he can come to us immediately afterwards.\"}'"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
"execution_count": 4,
|
244 |
+
"metadata": {},
|
245 |
+
"output_type": "execute_result"
|
246 |
+
}
|
247 |
+
],
|
248 |
+
"source": [
|
249 |
+
"import json\n",
|
250 |
+
"\n",
|
251 |
+
"json.dumps({'transcription': \" going along slushy country roads and speaking to damp audiences in draughty school rooms day after day for a fortnight. He'll have to put in an appearance at some place of worship on Sunday morning, and he can come to us immediately afterwards.\"})"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "code",
|
256 |
+
"execution_count": null,
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [],
|
259 |
+
"source": []
|
260 |
+
}
|
261 |
+
],
|
262 |
+
"metadata": {
|
263 |
+
"kernelspec": {
|
264 |
+
"display_name": "Python 3.9.13 ('dev': conda)",
|
265 |
+
"language": "python",
|
266 |
+
"name": "python3"
|
267 |
+
},
|
268 |
+
"language_info": {
|
269 |
+
"codemirror_mode": {
|
270 |
+
"name": "ipython",
|
271 |
+
"version": 3
|
272 |
+
},
|
273 |
+
"file_extension": ".py",
|
274 |
+
"mimetype": "text/x-python",
|
275 |
+
"name": "python",
|
276 |
+
"nbconvert_exporter": "python",
|
277 |
+
"pygments_lexer": "ipython3",
|
278 |
+
"version": "3.9.13"
|
279 |
+
},
|
280 |
+
"orig_nbformat": 4,
|
281 |
+
"vscode": {
|
282 |
+
"interpreter": {
|
283 |
+
"hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
|
284 |
+
}
|
285 |
+
}
|
286 |
+
},
|
287 |
+
"nbformat": 4,
|
288 |
+
"nbformat_minor": 2
|
289 |
+
}
|
handler.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
https://huggingface.co/tomiwa1a/video-search
|
3 |
+
"""
|
4 |
+
from typing import Dict
|
5 |
+
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from tqdm import tqdm
|
8 |
+
import whisper
|
9 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
10 |
+
import torch
|
11 |
+
import pytube
|
12 |
+
import time
|
13 |
+
|
14 |
+
|
15 |
+
class EndpointHandler():
|
16 |
+
# load the model
|
17 |
+
WHISPER_MODEL_NAME = "tiny.en"
|
18 |
+
SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
|
19 |
+
QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
|
20 |
+
SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
|
21 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 |
+
device_number = 0 if torch.cuda.is_available() else -1
|
23 |
+
|
24 |
+
def __init__(self, path=""):
|
25 |
+
|
26 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
+
device_number = 0 if torch.cuda.is_available() else -1
|
28 |
+
print(f'whisper and question_answer_model will use: {device}')
|
29 |
+
print(f'whisper and question_answer_model will use device_number: {device_number}')
|
30 |
+
|
31 |
+
t0 = time.time()
|
32 |
+
self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device_number)
|
33 |
+
t1 = time.time()
|
34 |
+
|
35 |
+
total = t1 - t0
|
36 |
+
print(f'Finished loading whisper_model in {total} seconds')
|
37 |
+
|
38 |
+
t0 = time.time()
|
39 |
+
self.sentence_transformer_model = SentenceTransformer(self.SENTENCE_TRANSFORMER_MODEL_NAME)
|
40 |
+
t1 = time.time()
|
41 |
+
|
42 |
+
total = t1 - t0
|
43 |
+
print(f'Finished loading sentence_transformer_model in {total} seconds')
|
44 |
+
|
45 |
+
t0 = time.time()
|
46 |
+
self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=device_number)
|
47 |
+
t1 = time.time()
|
48 |
+
|
49 |
+
total = t1 - t0
|
50 |
+
print(f'Finished loading summarizer in {total} seconds')
|
51 |
+
|
52 |
+
self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
|
53 |
+
t0 = time.time()
|
54 |
+
self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained \
|
55 |
+
(self.QUESTION_ANSWER_MODEL_NAME).to(device_number)
|
56 |
+
t1 = time.time()
|
57 |
+
total = t1 - t0
|
58 |
+
print(f'Finished loading question_answer_model in {total} seconds')
|
59 |
+
|
60 |
+
def __call__(self, data: Dict[str, str]) -> Dict:
|
61 |
+
"""
|
62 |
+
Args:
|
63 |
+
data (:obj:):
|
64 |
+
includes the URL to video for transcription
|
65 |
+
Return:
|
66 |
+
A :obj:`dict`:. transcribed dict
|
67 |
+
"""
|
68 |
+
# process input
|
69 |
+
print('data', data)
|
70 |
+
|
71 |
+
if "inputs" not in data:
|
72 |
+
raise Exception(f"data is missing 'inputs' key which EndpointHandler expects. Received: {data}"
|
73 |
+
f" See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler#2-create-endpointhandler-cp")
|
74 |
+
video_url = data.pop("video_url", None)
|
75 |
+
query = data.pop("query", None)
|
76 |
+
long_form_answer = data.pop("long_form_answer", None)
|
77 |
+
summarize = data.pop("summarize", False)
|
78 |
+
encoded_segments = {}
|
79 |
+
if video_url:
|
80 |
+
video_with_transcript = self.transcribe_video(video_url)
|
81 |
+
video_with_transcript['transcript']['transcription_source'] = f"whisper_{self.WHISPER_MODEL_NAME}"
|
82 |
+
encode_transcript = data.pop("encode_transcript", True)
|
83 |
+
if encode_transcript:
|
84 |
+
encoded_segments = self.combine_transcripts(video_with_transcript)
|
85 |
+
encoded_segments = {
|
86 |
+
"encoded_segments": self.encode_sentences(encoded_segments)
|
87 |
+
}
|
88 |
+
return {
|
89 |
+
**video_with_transcript,
|
90 |
+
**encoded_segments
|
91 |
+
}
|
92 |
+
elif summarize:
|
93 |
+
summary = self.summarize_video(data["segments"])
|
94 |
+
return {"summary": summary}
|
95 |
+
elif query:
|
96 |
+
if long_form_answer:
|
97 |
+
context = data.pop("context", None)
|
98 |
+
answer = self.generate_answer(query, context)
|
99 |
+
response = {
|
100 |
+
"answer": answer
|
101 |
+
}
|
102 |
+
|
103 |
+
return response
|
104 |
+
else:
|
105 |
+
query = [{"text": query, "id": ""}] if isinstance(query, str) else query
|
106 |
+
encoded_segments = self.encode_sentences(query)
|
107 |
+
|
108 |
+
response = {
|
109 |
+
"encoded_segments": encoded_segments
|
110 |
+
}
|
111 |
+
|
112 |
+
return response
|
113 |
+
|
114 |
+
else:
|
115 |
+
return {
|
116 |
+
"error": "'video_url' or 'query' must be provided"
|
117 |
+
}
|
118 |
+
|
119 |
+
def transcribe_video(self, video_url):
|
120 |
+
decode_options = {
|
121 |
+
# Set language to None to support multilingual,
|
122 |
+
# but it will take longer to process while it detects the language.
|
123 |
+
# Realized this by running in verbose mode and seeing how much time
|
124 |
+
# was spent on the decoding language step
|
125 |
+
"language": "en",
|
126 |
+
"verbose": True
|
127 |
+
}
|
128 |
+
yt = pytube.YouTube(video_url)
|
129 |
+
video_info = {
|
130 |
+
'id': yt.video_id,
|
131 |
+
'thumbnail': yt.thumbnail_url,
|
132 |
+
'title': yt.title,
|
133 |
+
'views': yt.views,
|
134 |
+
'length': yt.length,
|
135 |
+
# Althhough, this might seem redundant since we already have id
|
136 |
+
# but it allows the link to the video be accessed in 1-click in the API response
|
137 |
+
'url': f"https://www.youtube.com/watch?v={yt.video_id}"
|
138 |
+
}
|
139 |
+
stream = yt.streams.filter(only_audio=True)[0]
|
140 |
+
path_to_audio = f"{yt.video_id}.mp3"
|
141 |
+
stream.download(filename=path_to_audio)
|
142 |
+
t0 = time.time()
|
143 |
+
transcript = self.whisper_model.transcribe(path_to_audio, **decode_options)
|
144 |
+
t1 = time.time()
|
145 |
+
for segment in transcript['segments']:
|
146 |
+
# Remove the tokens array, it makes the response too verbose
|
147 |
+
segment.pop('tokens', None)
|
148 |
+
|
149 |
+
total = t1 - t0
|
150 |
+
print(f'Finished transcription in {total} seconds')
|
151 |
+
|
152 |
+
# postprocess the prediction
|
153 |
+
return {"transcript": transcript, 'video': video_info}
|
154 |
+
|
155 |
+
def encode_sentences(self, transcripts, batch_size=64):
|
156 |
+
"""
|
157 |
+
Encoding all of our segments at once or storing them locally would require too much compute or memory.
|
158 |
+
So we do it in batches of 64
|
159 |
+
:param transcripts:
|
160 |
+
:param batch_size:
|
161 |
+
:return:
|
162 |
+
"""
|
163 |
+
# loop through in batches of 64
|
164 |
+
all_batches = []
|
165 |
+
for i in tqdm(range(0, len(transcripts), batch_size)):
|
166 |
+
# find end position of batch (for when we hit end of data)
|
167 |
+
i_end = min(len(transcripts), i + batch_size)
|
168 |
+
# extract the metadata like text, start/end positions, etc
|
169 |
+
batch_meta = [{
|
170 |
+
**row
|
171 |
+
} for row in transcripts[i:i_end]]
|
172 |
+
# extract only text to be encoded by embedding model
|
173 |
+
batch_text = [
|
174 |
+
row['text'] for row in batch_meta
|
175 |
+
]
|
176 |
+
# create the embedding vectors
|
177 |
+
batch_vectors = self.sentence_transformer_model.encode(batch_text).tolist()
|
178 |
+
|
179 |
+
batch_details = [
|
180 |
+
{
|
181 |
+
**batch_meta[x],
|
182 |
+
'vectors': batch_vectors[x]
|
183 |
+
} for x in range(0, len(batch_meta))
|
184 |
+
]
|
185 |
+
all_batches.extend(batch_details)
|
186 |
+
|
187 |
+
return all_batches
|
188 |
+
|
189 |
+
def summarize_video(self, segments):
|
190 |
+
for index, segment in enumerate(segments):
|
191 |
+
segment['summary'] = self.summarizer(segment['text'])
|
192 |
+
segment['summary'] = segment['summary'][0]['summary_text']
|
193 |
+
print('index', index)
|
194 |
+
print('length', segment['length'])
|
195 |
+
print('text', segment['text'])
|
196 |
+
print('summary', segment['summary'])
|
197 |
+
|
198 |
+
return segments
|
199 |
+
|
200 |
+
def generate_answer(self, query, documents):
|
201 |
+
|
202 |
+
# concatenate question and support documents into BART input
|
203 |
+
conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
|
204 |
+
query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
|
205 |
+
|
206 |
+
model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True,
|
207 |
+
return_tensors="pt")
|
208 |
+
|
209 |
+
generated_answers_encoded = self.question_answer_model.generate(
|
210 |
+
input_ids=model_input["input_ids"].to(self.device),
|
211 |
+
attention_mask=model_input["attention_mask"].to(self.device),
|
212 |
+
min_length=64,
|
213 |
+
max_length=256,
|
214 |
+
do_sample=False,
|
215 |
+
early_stopping=True,
|
216 |
+
num_beams=8,
|
217 |
+
temperature=1.0,
|
218 |
+
top_k=None,
|
219 |
+
top_p=None,
|
220 |
+
eos_token_id=self.question_answer_tokenizer.eos_token_id,
|
221 |
+
no_repeat_ngram_size=3,
|
222 |
+
num_return_sequences=1)
|
223 |
+
answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
|
224 |
+
clean_up_tokenization_spaces=True)
|
225 |
+
return answer
|
226 |
+
|
227 |
+
@staticmethod
|
228 |
+
def combine_transcripts(video, window=6, stride=3):
|
229 |
+
"""
|
230 |
+
|
231 |
+
:param video:
|
232 |
+
:param window: number of sentences to combine
|
233 |
+
:param stride: number of sentences to 'stride' over, used to create overlap
|
234 |
+
:return:
|
235 |
+
"""
|
236 |
+
new_transcript_segments = []
|
237 |
+
|
238 |
+
video_info = video['video']
|
239 |
+
transcript_segments = video['transcript']['segments']
|
240 |
+
for i in tqdm(range(0, len(transcript_segments), stride)):
|
241 |
+
i_end = min(len(transcript_segments), i + window)
|
242 |
+
text = ' '.join(transcript['text']
|
243 |
+
for transcript in
|
244 |
+
transcript_segments[i:i_end])
|
245 |
+
# TODO: Should int (float to seconds) conversion happen at the API level?
|
246 |
+
start = int(transcript_segments[i]['start'])
|
247 |
+
end = int(transcript_segments[i]['end'])
|
248 |
+
new_transcript_segments.append({
|
249 |
+
**video_info,
|
250 |
+
**{
|
251 |
+
'start': start,
|
252 |
+
'end': end,
|
253 |
+
'title': video_info['title'],
|
254 |
+
'text': text,
|
255 |
+
'id': f"{video_info['id']}-t{start}",
|
256 |
+
'url': f"https://youtu.be/{video_info['id']}?t={start}",
|
257 |
+
'video_id': video_info['id'],
|
258 |
+
}
|
259 |
+
})
|
260 |
+
return new_transcript_segments
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/openai/whisper.git@8cf36f3508c9acd341a45eb2364239a3d81458b9
|
2 |
+
pytube==12.1.2
|
3 |
+
sentence-transformers==2.2.2
|
sample1.flac
ADDED
Binary file (282 kB). View file
|
|