Muhammad Taqi Raza committed on
Commit 8eeb1c7 · 1 Parent(s): 7f5f48b

correct infer_gradual

gradio_app.py CHANGED
@@ -166,7 +166,7 @@ with demo:
     depth_guidance_input = gr.Number(value=1.0, label="Depth Guidance")
     window_input = gr.Number(value=64, label="Window Size")
     overlap_input = gr.Number(value=25, label="Overlap")
-    maxres_input = gr.Number(value=1024, label="Max Resolution")
+    maxres_input = gr.Number(value=1920, label="Max Resolution")
     sample_size = gr.Textbox(label="Sample Size (height, width)", placeholder="e.g., 384, 672", value="384, 672")
     seed_input = gr.Number(value=43, label="Seed")
     height = gr.Number(value=576, label="Height")
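
The only functional change to gradio_app.py is the default "Max Resolution" moving from 1024 to 1920. In read_video_frames (patched in inference/v2v_data/models/utils.py below), max_res caps the longer side of the 64-aligned working resolution, so the larger default lets full-HD inputs keep their native size instead of being downscaled. A minimal sketch of that interaction, assuming a 1080x1920 input (working_resolution is an illustrative helper, not a repo function):

    def working_resolution(original_height, original_width, max_res):
        # Same rounding rule as the new read_video_frames: snap each side to a
        # multiple of 64, then rescale only if the longer side exceeds max_res.
        height = round(original_height / 64) * 64
        width = round(original_width / 64) * 64
        if max(height, width) > max_res:
            scale = max_res / max(original_height, original_width)
            height = round(original_height * scale / 64) * 64
            width = round(original_width * scale / 64) * 64
        return height, width

    print(working_resolution(1080, 1920, 1024))  # old default -> (576, 1024)
    print(working_resolution(1080, 1920, 1920))  # new default -> (1088, 1920)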
inference/v2v_data/demo.py CHANGED
@@ -111,14 +111,14 @@ class GetAnchorVideos:
 
     def infer_gradual(self, opts):
         frames = read_video_frames(
-            opts.video_path, opts.video_length, opts.stride, opts.max_res, height = opts.height, width = opts.width
+            opts.video_path, opts.video_length, opts.stride, opts.max_res
         )
         vr = VideoReader(opts.video_path, ctx=cpu(0))
         frame_shape = vr[0].shape  # (H, W, 3)
         ori_resolution = frame_shape[:2]
         print(f"==> original video shape: {frame_shape}")
-        # target_resolution = get_center_crop_resolution(original_resoultion = ori_resolution, height = opts.height, width = opts.width)
-        # print(f"==> target video shape resized: {target_resolution}")
+        target_resolution = get_center_crop_resolution(original_resoultion = ori_resolution, height = opts.height, width = opts.width)
+        print(f"==> target video shape resized: {target_resolution}")
 
         prompt = self.get_caption(opts, opts.video_path)
         depths = self.depth_estimater.infer(
@@ -138,8 +138,8 @@ class GetAnchorVideos:
         print(f"==> opts video length: {opts.video_length}")
         assert frames.shape[0] == opts.video_length
 
-        # depths = center_crop_to_ratio(depths, resolution=target_resolution)
-        # frames = center_crop_to_ratio(frames, resolution=target_resolution)
+        depths = center_crop_to_ratio(depths, resolution=target_resolution)
+        frames = center_crop_to_ratio(frames, resolution=target_resolution)
         pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)
         warped_images = []
         masks = []
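
infer_gradual now resizes purely by max_res inside read_video_frames and then center-crops both depths and frames to a target resolution derived from the source video. Neither get_center_crop_resolution nor center_crop_to_ratio appears in this diff, so the following is only a plausible reading of their behavior inferred from the call sites, not the repo's code (the misspelled original_resoultion keyword is kept to match the call above):

    def get_center_crop_resolution(original_resoultion, height=576, width=1024):
        # Assumed behavior: largest (h, w) with the target aspect ratio that
        # fits inside the original resolution.
        orig_h, orig_w = original_resoultion
        target_ratio = width / height
        if orig_w / orig_h > target_ratio:
            crop_h, crop_w = orig_h, int(orig_h * target_ratio)
        else:
            crop_h, crop_w = int(orig_w / target_ratio), orig_w
        return crop_h, crop_w

    def center_crop_to_ratio(frames, resolution):
        # Assumed behavior: center-crop an array shaped [T, H, W, ...] to `resolution`.
        crop_h, crop_w = resolution
        h, w = frames.shape[1], frames.shape[2]
        top, left = (h - crop_h) // 2, (w - crop_w) // 2
        return frames[:, top:top + crop_h, left:left + crop_w]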
inference/v2v_data/models/utils.py CHANGED
@@ -28,16 +28,57 @@ from decord import VideoReader, cpu
 
 from PIL import Image
 
-def read_video_frames(video_path, process_length, stride, max_res, dataset="open", height=576, width=1024):
+# def read_video_frames(video_path, process_length, target_fps, max_res, dataset="open"):
+#     if dataset == "open":
+#         print("==> processing video: ", video_path)
+#         vid = VideoReader(video_path, ctx=cpu(0))
+#         print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
+#         original_height, original_width = vid.get_batch([0]).shape[1:3]
+#         height = round(original_height / 64) * 64
+#         width = round(original_width / 64) * 64
+#         if max(height, width) > max_res:
+#             scale = max_res / max(original_height, original_width)
+#             height = round(original_height * scale / 64) * 64
+#             width = round(original_width * scale / 64) * 64
+#     else:
+#         height = dataset_res_dict[dataset][0]
+#         width = dataset_res_dict[dataset][1]
+#
+#     vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
+#
+#     fps = vid.get_avg_fps() if target_fps == -1 else target_fps
+#     stride = round(vid.get_avg_fps() / fps)
+#     stride = max(stride, 1)
+#     frames_idx = list(range(0, len(vid), stride))
+#     print(
+#         f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
+#     )
+#     if process_length != -1 and process_length < len(frames_idx):
+#         frames_idx = frames_idx[:process_length]
+#     print(
+#         f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
+#     )
+#     frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
+#
+#     return frames, fps
+def read_video_frames(video_path, process_length, stride, max_res, dataset="open"):
     def is_image(path):
         return any(path.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.bmp'])
-
-    if is_image(video_path):
+
+    if is_image(video_path):
         print("==> Detected image. Loading as single-frame video:", video_path)
         img = Image.open(video_path).convert("RGB")
-        # FIXME: hard coded
-        width = width
-        height = height
+        original_width = img.width
+        original_height = img.height
+
+        height = round(original_height / 64) * 64
+        width = round(original_width / 64) * 64
+
+        if max(height, width) > max_res:
+            scale = max_res / max(original_height, original_width)
+            height = round(original_height * scale / 64) * 64
+            width = round(original_width * scale / 64) * 64
+
         img = img.resize((width, height), Image.BICUBIC)
         img = np.array(img).astype("float32") / 255.0  # [H, W, 3]
         frames = img[None, ...]  # [1, H, W, 3]
@@ -49,9 +90,15 @@ def read_video_frames(video_path, process_length, stride, max_res, dataset="open
         vid = VideoReader(video_path, ctx=cpu(0))
         print("==> original video shape:", (len(vid), *vid.get_batch([0]).shape[1:]))
 
-        # FIXME: hard coded
-        width = width
-        height = height
+        original_height, original_width = vid.get_batch([0]).shape[1:3]
+
+        height = round(original_height / 64) * 64
+        width = round(original_width / 64) * 64
+
+        if max(height, width) > max_res:
+            scale = max_res / max(original_height, original_width)
+            height = round(original_height * scale / 64) * 64
+            width = round(original_width * scale / 64) * 64
 
         vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
 
@@ -64,8 +111,6 @@ def read_video_frames(video_path, process_length, stride, max_res, dataset="open
 
     return frames
 
-
-
 def save_video(data, images_path, folder=None, fps=8):
     if isinstance(data, np.ndarray):
         tensor_data = (torch.from_numpy(data) * 255).to(torch.uint8)
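
With the height/width parameters gone, the rewritten read_video_frames derives the working size from the source itself in both branches: round each side to a multiple of 64, then scale down only when the longer side exceeds max_res. A hedged usage sketch matching the new call in demo.py (the path and argument values are placeholders; the returned array is presumably float32 in [0, 1] with shape [T, H, W, 3], as in the single-image branch shown above):

    # Placeholder call; relies on decord, numpy, and Pillow as imported at the top of utils.py.
    frames = read_video_frames(
        "input.mp4",  # video_path: a video, or a single .jpg/.jpeg/.png/.bmp image
        49,           # process_length: number of frames to keep
        1,            # stride: frame sampling stride
        1920,         # max_res: cap on the longer side of the 64-aligned working size
    )
    print(frames.shape, frames.dtype)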