Commit 882e9a6 (verified) · committed by yeliudev
1 Parent(s): 54a174a

Update app.py

Files changed (1): app.py (+40 -75)
app.py CHANGED
@@ -6,15 +6,13 @@ import os
 import random
 import time
 from functools import partial
-from threading import Thread

 import gradio as gr
 import nncore
-import spaces
 import torch
 from huggingface_hub import snapshot_download
-from transformers import TextIteratorStreamer

+import spaces
 from videomind.constants import GROUNDER_PROMPT, PLANNER_PROMPT, VERIFIER_PROMPT
 from videomind.dataset.utils import process_vision_info
 from videomind.model.builder import build_model
@@ -63,43 +61,6 @@ function init() {
 """


-class CustomStreamer(TextIteratorStreamer):
-
-    def put(self, value):
-        if len(value.shape) > 1 and value.shape[0] > 1:
-            raise ValueError('TextStreamer only supports batch size 1')
-        elif len(value.shape) > 1:
-            value = value[0]
-
-        if self.skip_prompt and self.next_tokens_are_prompt:
-            self.next_tokens_are_prompt = False
-            return
-
-        self.token_cache.extend(value.tolist())
-
-        # force skipping eos token
-        if self.token_cache[-1] == self.tokenizer.eos_token_id:
-            self.token_cache = self.token_cache[:-1]
-
-        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
-
-        # cache decoded text for future use
-        self.text_cache = text
-
-        if text.endswith('\n'):
-            printable_text = text[self.print_len:]
-            self.token_cache = []
-            self.print_len = 0
-        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
-            printable_text = text[self.print_len:]
-            self.print_len += len(printable_text)
-        else:
-            printable_text = text[self.print_len:text.rfind(' ') + 1]
-            self.print_len += len(printable_text)
-
-        self.on_finalized_text(printable_text)
-
-
 def seconds_to_hms(seconds):
     hours, remainder = divmod(round(seconds), 3600)
     minutes, seconds = divmod(remainder, 60)
@@ -128,7 +89,7 @@ def reset_components():


 @spaces.GPU
-def main(video, prompt, role, temperature, max_new_tokens, model, processor, streamer, device):
+def main(video, prompt, role, temperature, max_new_tokens, model, processor, device):
     history = []

     if not video:
@@ -204,9 +165,8 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
         model.base_model.enable_adapter_layers()
         model.set_adapter('planner')

-        generation_kwargs = dict(
+        output_ids = model.generate(
             **data,
-            streamer=streamer,
             do_sample=temperature > 0,
             temperature=temperature if temperature > 0 else None,
             top_p=None,
@@ -214,15 +174,18 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
             repetition_penalty=None,
             max_new_tokens=max_new_tokens)

-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
+        assert data.input_ids.size(0) == output_ids.size(0) == 1
+        output_ids = output_ids[0, data.input_ids.size(1):]
+        if output_ids[-1] == processor.tokenizer.eos_token_id:
+            output_ids = output_ids[:-1]
+        response = processor.decode(output_ids, clean_up_tokenization_spaces=False)

-        skipped = False
-        for i, text in enumerate(streamer):
-            if text and not skipped:
+        for i, text in enumerate(response.split(' ')):
+            if i == 0:
                 history[-1]['content'] = history[-1]['content'].rstrip('.')
-                skipped = True
-            history[-1]['content'] += text
+                history[-1]['content'] += text
+            else:
+                history[-1]['content'] += ' ' + text
             yield history

         elapsed_time = round(time.perf_counter() - start_time, 1)
@@ -230,7 +193,7 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
         yield history

         try:
-            parsed = json.loads(streamer.text_cache)
+            parsed = json.loads(response)
             action = parsed[0] if isinstance(parsed, list) else parsed
             if action['type'].lower() == 'grounder' and action['value']:
                 query = action['value']
@@ -301,9 +264,8 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
         model.base_model.enable_adapter_layers()
         model.set_adapter('grounder')

-        generation_kwargs = dict(
+        output_ids = model.generate(
             **data,
-            streamer=streamer,
             do_sample=temperature > 0,
             temperature=temperature if temperature > 0 else None,
             top_p=None,
@@ -311,15 +273,18 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
             repetition_penalty=None,
             max_new_tokens=max_new_tokens)

-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
+        assert data.input_ids.size(0) == output_ids.size(0) == 1
+        output_ids = output_ids[0, data.input_ids.size(1):]
+        if output_ids[-1] == processor.tokenizer.eos_token_id:
+            output_ids = output_ids[:-1]
+        response = processor.decode(output_ids, clean_up_tokenization_spaces=False)

-        skipped = False
-        for i, text in enumerate(streamer):
-            if text and not skipped:
+        for i, text in enumerate(response.split(' ')):
+            if i == 0:
                 history[-1]['content'] = history[-1]['content'].rstrip('.')
-                skipped = True
-            history[-1]['content'] += text
+                history[-1]['content'] += text
+            else:
+                history[-1]['content'] += ' ' + text
             yield history

         elapsed_time = round(time.perf_counter() - start_time, 1)
@@ -520,9 +485,8 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
     data = data.to(device)

     with model.disable_adapter():
-        generation_kwargs = dict(
+        output_ids = model.generate(
             **data,
-            streamer=streamer,
             do_sample=temperature > 0,
             temperature=temperature if temperature > 0 else None,
             top_p=None,
@@ -530,25 +494,28 @@ def main(video, prompt, role, temperature, max_new_tokens, model, processor, str
             repetition_penalty=None,
             max_new_tokens=max_new_tokens)

-    t = Thread(target=model.generate, kwargs=generation_kwargs)
-    t.start()
+    assert data.input_ids.size(0) == output_ids.size(0) == 1
+    output_ids = output_ids[0, data.input_ids.size(1):]
+    if output_ids[-1] == processor.tokenizer.eos_token_id:
+        output_ids = output_ids[:-1]
+    response = processor.decode(output_ids, clean_up_tokenization_spaces=False)

-    skipped = False
-    for i, text in enumerate(streamer):
-        if text and not skipped:
+    for i, text in enumerate(response.split(' ')):
+        if i == 0:
             history[-1]['content'] = history[-1]['content'].rstrip('.')
-            skipped = True
-        history[-1]['content'] += text
+            history[-1]['content'] += text
+        else:
+            history[-1]['content'] += ' ' + text
         yield history

     elapsed_time = round(time.perf_counter() - start_time, 1)
     history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
     yield history

     if 'gnd' in role and do_grounding:
-        response = f'After zooming in and analyzing the target moment, I finalize my answer: <span style="color:green">{streamer.text_cache}</span>'
+        response = f'After zooming in and analyzing the target moment, I finalize my answer: <span style="color:green">{response}</span>'
     else:
-        response = f'After watching the whole video, my answer is: <span style="color:green">{streamer.text_cache}</span>'
+        response = f'After watching the whole video, my answer is: <span style="color:green">{response}</span>'

     history.append({'role': 'assistant', 'content': ''})
     for i, text in enumerate(response.split(' ')):
@@ -572,11 +539,9 @@ if __name__ == '__main__':
     print('Initializing role *verifier*')
    model.load_adapter(nncore.join(MODEL, 'verifier'), adapter_name='verifier')

-    streamer = CustomStreamer(processor.tokenizer, skip_prompt=True)
-
     device = next(model.parameters()).device

-    main = partial(main, model=model, processor=processor, streamer=streamer, device=device)
+    main = partial(main, model=model, processor=processor, device=device)

     path = os.path.dirname(os.path.realpath(__file__))
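Note: the commit replaces the threaded TextIteratorStreamer pipeline (the CustomStreamer subclass fed by a background Thread) with a single blocking model.generate() call whose output is decoded once and then re-emitted word by word, so the Gradio chat still appears to stream. Below is a minimal, self-contained sketch of that generate-then-decode pattern; the GPT-2 checkpoint and the prompt are illustrative stand-ins and are not part of this commit.

# Sketch of the generate-then-decode pattern used in the new app.py.
# The checkpoint and prompt below are assumptions for illustration only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

data = tokenizer('Q: What happens in the video? A:', return_tensors='pt')

# Blocking generation: no streamer object, no background thread.
with torch.inference_mode():
    output_ids = model.generate(**data, do_sample=False, max_new_tokens=32)

# Post-processing mirrors the diff above: keep batch size 1, strip the prompt
# tokens, drop a trailing EOS token if present, then decode in one call.
assert data.input_ids.size(0) == output_ids.size(0) == 1
output_ids = output_ids[0, data.input_ids.size(1):]
if output_ids[-1] == tokenizer.eos_token_id:
    output_ids = output_ids[:-1]
response = tokenizer.decode(output_ids, clean_up_tokenization_spaces=False)

# Pseudo-streaming: the finished text is re-emitted word by word, which is
# what keeps the chat UI incremental without a token-level streamer.
shown = ''
for i, word in enumerate(response.split(' ')):
    shown += word if i == 0 else ' ' + word
    print(shown)

One plausible reason for the change: the handler runs under @spaces.GPU (ZeroGPU), where a single blocking generate call inside the time-boxed GPU context is simpler to manage than coordinating a background generation thread with an iterator streamer.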