svjack committed
Commit 340485a · verified · 1 Parent(s): f4983c0

Update app.py

Files changed (1)
  1. app.py +85 -0
app.py CHANGED
@@ -14,6 +14,71 @@ pip install -r requirements.txt
 python app.py
 '''
 
+from datasets import load_dataset
+vid_ds = load_dataset("svjack/video-dataset-genshin-impact-ep-character-organized")
+
+import decord
+import cv2
+import os
+
+def save_video_to_mp4(input_video, output_path):
+    """
+    Save a video read with decord.VideoReader, or a video file, as an MP4 file.
+
+    Args:
+        input_video (str or decord.VideoReader): path to a video file, or a decord.VideoReader object.
+        output_path (str): path of the output MP4 file.
+    """
+    # If the input is a path, create a VideoReader object
+    if isinstance(input_video, str):
+        vr = decord.VideoReader(input_video)
+    elif isinstance(input_video, decord.VideoReader):
+        vr = input_video
+    else:
+        raise ValueError("Input must be a video file path or a decord.VideoReader object")
+
+    # Read the basic video properties
+    fps = vr.get_avg_fps()
+    width = vr[0].shape[1]
+    height = vr[0].shape[0]
+
+    # Create the VideoWriter object
+    #fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # codec
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # codec
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    # Read frame by frame and write out
+    for frame in vr:
+        frame = frame.asnumpy()  # convert the decord frame to a numpy array
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # convert the color space
+        out.write(frame)
+
+    # Release resources
+    out.release()
+    print(f"Video saved to {output_path}")
+
+#rd = vid_ds["train"][0]["video"]
+#rd
+#save_video_to_mp4(rd, "output_video.mp4")
+
+out_dir = "examples"
+import shutil
+from uuid import uuid1
+from tqdm import tqdm
+if os.path.exists(out_dir):
+    shutil.rmtree(out_dir)
+
+vid_l = []
+os.makedirs(out_dir, exist_ok = True)
+for exp in tqdm(vid_ds["train"]):
+    rd = exp["video"]
+    target_path = os.path.join(out_dir, "{}.mp4".format(uuid1()))
+    vid_l.append(target_path)
+    save_video_to_mp4(rd, target_path)
+    if len(vid_l) >= 10:
+        break
+
+
 import subprocess
 import re
 from typing import List, Tuple, Optional
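
As a quick, non-essential sanity check on the exported clips, the MP4 files written by save_video_to_mp4 can be re-opened with decord and their frame count, frame rate and resolution printed. This is a minimal sketch, not part of the commit; it assumes the loop above has already populated the examples/ directory.

import os
import decord

# Sketch only: re-open each exported MP4 and report its basic properties.
for name in sorted(os.listdir("examples")):
    path = os.path.join("examples", name)
    vr = decord.VideoReader(path)
    h, w, _ = vr[0].shape  # first frame: height x width x channels
    print(f"{name}: {len(vr)} frames @ {vr.get_avg_fps():.2f} fps, {w}x{h}")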
 
@@ -576,6 +641,24 @@ with gr.Blocks(css=css) as demo:
         ],
         queue = False
     )
+
+    video_in.change(
+        fn = preprocess_video_in,
+        inputs = [video_in],
+        outputs = [
+            first_frame_path,
+            tracking_points,         # update Tracking Points in the gr.State([]) object
+            trackings_input_label,   # update Tracking Labels in the gr.State([]) object
+            input_first_frame_image, # hidden component used as a reference when clearing points
+            points_map,              # Image component where we add new tracking points
+            video_frames_dir,        # where the frames extracted from video_in are stored
+            scanned_frames,          # frames scanned by SAM2
+            stored_inference_state,  # SAM2 inference state
+            stored_frame_names,      #
+            video_in_drawer,         # Accordion that hides the uploaded video player
+        ],
+        queue = False
+    )
 
 
     # triggered when we click on image to add new points
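
The new video_in.change listener follows the usual Gradio event pattern: preprocess_video_in receives the uploaded video path, and its return values are assigned, in order, to the components listed in outputs. The stripped-down sketch below shows the same wiring in isolation; the handler and the Textbox are hypothetical stand-ins, not the app's own components.

import gradio as gr

def on_video_change(video_path):
    # Hypothetical stand-in for preprocess_video_in:
    # return one value per component listed in `outputs`.
    if video_path is None:
        return "No video loaded"
    return f"Loaded: {video_path}"

with gr.Blocks() as sketch:
    video_in = gr.Video(label="Input video")
    status = gr.Textbox(label="Status")
    # Fires whenever the video value changes (upload, example click, or clear).
    video_in.change(fn=on_video_change, inputs=[video_in], outputs=[status], queue=False)

sketch.launch()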
 
@@ -660,4 +743,6 @@ with gr.Blocks(css=css) as demo:
         outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
     )
 
+    gr.Examples(vid_l, video_in)
+
 demo.launch(share = True)
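
gr.Examples(vid_l, video_in) registers the ten exported clips as clickable examples for the video input; selecting one sets the value of video_in, which should in turn trigger the .change handler added above. A minimal standalone sketch of the same idea, with hypothetical file paths standing in for vid_l:

import gradio as gr

with gr.Blocks() as sketch:
    video_in = gr.Video(label="Input video")
    # First argument: the example values (here, paths to local MP4 files);
    # second argument: the input component they populate when clicked.
    gr.Examples(["examples/clip_a.mp4", "examples/clip_b.mp4"], video_in)

sketch.launch()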