Muhammad Taqi Raza commited on
Commit
0cc03a7
Β·
1 Parent(s): 0d2f841

print shapes

Browse files
inference/cli_demo_camera_i2v_pcd.py CHANGED
@@ -75,15 +75,20 @@ def maxpool_mask_tensor(mask_tensor):
75
  """
76
  T, H, W = mask_tensor.shape
77
  assert T % 12 == 0, "T must be divisible by 12 (e.g., 48)"
78
- assert H % 30 == 0 and W % 45 == 0, "H and W must be divisible by 30 and 45"
 
 
 
 
 
79
 
80
  # Reshape to (B=T, C=1, H, W) for 2D spatial pooling
81
  x = mask_tensor.unsqueeze(1).float() # (T, 1, H, W)
82
- x_pooled = F.max_pool2d(x, kernel_size=(H // 30, W // 45)) # β†’ (T, 1, 30, 45)
83
 
84
  # Temporal pooling: reshape to (12, T//12, 30, 45) and max along dim=1
85
  t_groups = T // 12
86
- x_pooled = x_pooled.view(12, t_groups, 30, 45)
87
  pooled_mask = torch.amax(x_pooled, dim=1) # β†’ (12, 30, 45)
88
 
89
  # Add a zero frame at the beginning: shape (1, 30, 45)
@@ -105,15 +110,19 @@ def avgpool_mask_tensor(mask_tensor):
105
  """
106
  T, H, W = mask_tensor.shape
107
  assert T % 12 == 0, "T must be divisible by 12 (e.g., 48)"
108
- assert H % 30 == 0 and W % 45 == 0, "H and W must be divisible by 30 and 45"
 
 
 
 
109
 
110
  # Spatial average pooling
111
  x = mask_tensor.unsqueeze(1).float() # (T, 1, H, W)
112
- x_pooled = F.avg_pool2d(x, kernel_size=(H // 30, W // 45)) # β†’ (T, 1, 30, 45)
113
 
114
  # Temporal pooling
115
  t_groups = T // 12
116
- x_pooled = x_pooled.view(12, t_groups, 30, 45)
117
  pooled_avg = torch.mean(x_pooled, dim=1) # β†’ (12, 30, 45)
118
 
119
  # Threshold: keep only when > 0.5
 
75
  """
76
  T, H, W = mask_tensor.shape
77
  assert T % 12 == 0, "T must be divisible by 12 (e.g., 48)"
78
+ # assert H % 30 == 0 and W % 45 == 0, "H and W must be divisible by 30 and 45"
79
+ assert H % 8 == 0 and W % 8 == 0, "H and W must be divisible by 8 for spatial pooling"
80
+
81
+ downsampling_factor_h = H // 8
82
+ downsampling_factor_w = W // 8
83
+
84
 
85
  # Reshape to (B=T, C=1, H, W) for 2D spatial pooling
86
  x = mask_tensor.unsqueeze(1).float() # (T, 1, H, W)
87
+ x_pooled = F.max_pool2d(x, kernel_size=(H // downsampling_factor_h, W // downsampling_factor_w)) # β†’ (T, 1, 30, 45)
88
 
89
  # Temporal pooling: reshape to (12, T//12, 30, 45) and max along dim=1
90
  t_groups = T // 12
91
+ x_pooled = x_pooled.view(12, t_groups, downsampling_factor_h, downsampling_factor_w)
92
  pooled_mask = torch.amax(x_pooled, dim=1) # β†’ (12, 30, 45)
93
 
94
  # Add a zero frame at the beginning: shape (1, 30, 45)
 
110
  """
111
  T, H, W = mask_tensor.shape
112
  assert T % 12 == 0, "T must be divisible by 12 (e.g., 48)"
113
+ # assert H % 30 == 0 and W % 45 == 0, "H and W must be divisible by 30 and 45"
114
+ assert H % 8 == 0 and W % 8 == 0, "H and W must be divisible by 8 for spatial pooling"
115
+
116
+ downsampling_factor_h = H // 8
117
+ downsampling_factor_w = W // 8
118
 
119
  # Spatial average pooling
120
  x = mask_tensor.unsqueeze(1).float() # (T, 1, H, W)
121
+ x_pooled = F.avg_pool2d(x, kernel_size=(H // downsampling_factor_h, W // downsampling_factor_w)) # β†’ (T, 1, 30, 45)
122
 
123
  # Temporal pooling
124
  t_groups = T // 12
125
+ x_pooled = x_pooled.view(12, t_groups, downsampling_factor_h, downsampling_factor_w)
126
  pooled_avg = torch.mean(x_pooled, dim=1) # β†’ (12, 30, 45)
127
 
128
  # Threshold: keep only when > 0.5