reab5555 commited on
Commit
755e5aa
·
verified ·
1 Parent(s): 6b3604d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -10
app.py CHANGED
@@ -104,18 +104,42 @@ def analyze_image(image):
104
 
105
 
106
  def show_mask(mask, ax, random_color=False):
107
- if random_color:
108
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
109
- else:
110
- color = np.array([1.0, 0.0, 0.0, 0.5])
111
-
112
- if len(mask.shape) == 4:
113
- mask = mask[0, 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
116
- mask_image[mask > 0] = color
117
 
118
- ax.imshow(mask_image)
 
 
 
 
 
119
 
120
 
121
  def process_image_detection(image, target_label, surprise_rating):
 
104
 
105
 
106
  def show_mask(mask, ax, random_color=False):
107
+ try:
108
+ # Debug print to understand mask type
109
+ print(f"show_mask input type: {type(mask)}")
110
+
111
+ # Convert mask if it's a tuple
112
+ if isinstance(mask, tuple):
113
+ if len(mask) > 0 and mask[0] is not None:
114
+ mask = mask[0]
115
+ else:
116
+ raise ValueError("Invalid mask tuple")
117
+
118
+ # Convert torch tensor to numpy if needed
119
+ if torch.is_tensor(mask):
120
+ mask = mask.cpu().numpy()
121
+
122
+ # Handle 4D tensor/array case
123
+ if len(mask.shape) == 4:
124
+ mask = mask[0, 0]
125
+ # Handle 3D tensor/array case
126
+ elif len(mask.shape) == 3:
127
+ mask = mask[0]
128
+
129
+ if random_color:
130
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
131
+ else:
132
+ color = np.array([1.0, 0.0, 0.0, 0.5])
133
 
134
+ mask_image = np.zeros((*mask.shape, 4), dtype=np.float32)
135
+ mask_image[mask > 0] = color
136
 
137
+ ax.imshow(mask_image)
138
+
139
+ except Exception as e:
140
+ print(f"show_mask error: {str(e)}")
141
+ print(f"mask shape: {getattr(mask, 'shape', 'no shape')}")
142
+ raise
143
 
144
 
145
  def process_image_detection(image, target_label, surprise_rating):