lalalalalalalalalala committed on
Commit
bca51b0
·
verified ·
1 Parent(s): e09b4da

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +46 -47
run.py CHANGED
@@ -18,53 +18,52 @@ def load_hf_dataset(dataset_path, auth_token):
18
 
19
def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption a single video or every video in a Hugging Face parquet shard.

    Exactly one source is used, checked in order: ``video_src`` (single
    video), then ``video_hf`` + ``video_hf_auth`` (dataset shard selected by
    ``parquet_index``). ``video_od*`` / ``video_gd*`` are accepted for
    interface compatibility but unused here.

    Returns a 3-tuple ``(result, progress_text, debug_image)``:
      * single video  -> (caption string, progress log, concatenated-frames image)
      * HF dataset    -> (path to CSV of md5/caption rows, progress log, None)
      * no source     -> ("", "No video source selected.", None)
    """
    progress_info = []
    # delete=False: the CSV path is handed back to the caller, which reads the
    # file after this function returns, so it must survive the context manager.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, newline='') as csv_file:
        csv_filename = csv_file.name
        fieldnames = ['md5', 'caption']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        if video_src:
            video = video_src
            processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
            frames = processor._decode(video)
            base64_list = processor.to_base64_list(frames)
            debug_image = processor.concatenate(frames)
            # Validate credentials only after decoding so the user still gets
            # the frame-count feedback and debug image.
            if not key or not endpoint:
                return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
            api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
            caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
            progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
            writer.writerow({'md5': 'single_video', 'caption': caption})
            return f"{caption}", "\n".join(progress_info), debug_image
        elif video_hf and video_hf_auth:
            current_file_path = os.path.abspath(__file__)
            current_directory = os.path.dirname(current_file_path)
            progress_info.append('Begin processing Hugging Face dataset.')
            temp_parquet_file = hf_hub_download(
                repo_id=video_hf,
                filename='data/' + str(parquet_index).zfill(6) + '.parquet',
                repo_type="dataset",
                token=video_hf_auth,
            )
            parquet_file = pq.ParquetFile(temp_parquet_file)
            # batch_size=1: one video blob per row; avoids holding several
            # large videos in memory at once.
            for batch in parquet_file.iter_batches(batch_size=1):
                df = batch.to_pandas()
                video = df['video'][0]
                md5 = hashlib.md5(video).hexdigest()
                with tempfile.NamedTemporaryFile(dir=current_directory) as temp_file:
                    temp_file.write(video)
                    # BUG FIX: flush so the decoder (which reopens the file by
                    # path) sees the full payload, not a partially-buffered one.
                    temp_file.flush()
                    video_path = temp_file.name
                    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
                    frames = processor._decode(video_path)
                    base64_list = processor.to_base64_list(frames)
                    api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
                    caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
                    writer.writerow({'md5': md5, 'caption': caption})
                    progress_info.append(f"Processed video with MD5: {md5}")
            return csv_filename, "\n".join(progress_info), None
        else:
            return "", "No video source selected.", None
68
 
69
  with gr.Blocks() as Core:
70
  with gr.Row(variant="panel"):
 
18
 
19
def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption a single video or every video in a Hugging Face parquet shard.

    Exactly one source is used, checked in order: ``video_src`` (single
    video), then ``video_hf`` + ``video_hf_auth`` (dataset shard selected by
    ``parquet_index``). ``video_od*`` / ``video_gd*`` are accepted for
    interface compatibility but unused here.

    Returns a 3-tuple ``(result, progress_text, debug_image)``:
      * single video  -> (caption string, progress log, concatenated-frames image)
      * HF dataset    -> (path to CSV of md5/caption rows, progress log, None)
      * no source     -> ("", "No video source selected.", None)
    """
    progress_info = []
    # BUG FIX: this previously used `with tempfile.TemporaryDirectory()`,
    # which removes the directory the moment the function returns — but
    # csv_filename is returned to the caller, which opens the file *after*
    # the return, so it got a path to an already-deleted file. mkdtemp()
    # keeps the directory alive; cleanup is the caller's/OS's responsibility.
    temp_dir = tempfile.mkdtemp()
    csv_filename = os.path.join(temp_dir, str(parquet_index) + '_caption.csv')
    with open(csv_filename, mode='w', newline='') as csv_file:
        fieldnames = ['md5', 'caption']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

        if video_src:
            video = video_src
            processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
            frames = processor._decode(video)
            base64_list = processor.to_base64_list(frames)
            debug_image = processor.concatenate(frames)
            # Validate credentials only after decoding so the user still gets
            # the frame-count feedback and debug image.
            if not key or not endpoint:
                return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
            api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
            caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
            progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
            writer.writerow({'md5': 'single_video', 'caption': caption})
            return f"{caption}", "\n".join(progress_info), debug_image
        elif video_hf and video_hf_auth:
            progress_info.append('Begin processing Hugging Face dataset.')
            temp_parquet_file = hf_hub_download(
                repo_id=video_hf,
                filename='data/' + str(parquet_index).zfill(6) + '.parquet',
                repo_type="dataset",
                token=video_hf_auth,
            )
            parquet_file = pq.ParquetFile(temp_parquet_file)
            # batch_size=1: one video blob per row; avoids holding several
            # large videos in memory at once.
            for batch in parquet_file.iter_batches(batch_size=1):
                df = batch.to_pandas()
                video = df['video'][0]
                md5 = hashlib.md5(video).hexdigest()
                with tempfile.NamedTemporaryFile(dir=temp_dir) as temp_file:
                    temp_file.write(video)
                    # BUG FIX: flush so the decoder (which reopens the file by
                    # path) sees the full payload, not a partially-buffered one.
                    temp_file.flush()
                    video_path = temp_file.name
                    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
                    frames = processor._decode(video_path)
                    base64_list = processor.to_base64_list(frames)
                    api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
                    caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
                    writer.writerow({'md5': md5, 'caption': caption})
                    progress_info.append(f"Processed video with MD5: {md5}")
            return csv_filename, "\n".join(progress_info), None
        else:
            return "", "No video source selected.", None
 
67
 
68
  with gr.Blocks() as Core:
69
  with gr.Row(variant="panel"):