update: dff_nmf

yolov8.py CHANGED
@@ -162,7 +162,6 @@ class DeepFeatureFactorization:
 
 
 
-
 def dff_nmf(image, target_lyr, n_components):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     mean = [0.485, 0.456, 0.406] # Mean for RGB channels
@@ -171,77 +170,52 @@ def dff_nmf(image, target_lyr, n_components):
     rgb_img_float = np.float32(img) / 255.0
     input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
 
-    model = YOLO('yolov8s.pt')
-    dff= DeepFeatureFactorization(model=model,
-
-
+    model = YOLO('yolov8s.pt')  # Ensure the model is loaded correctly
+    dff = DeepFeatureFactorization(model=model,
+                                   target_layer=model.model.model[int(target_lyr)],
+                                   computation_on_concepts=None)
 
     concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-    objectness = output1[..., 4] # Objectness score (index 4)
-    class_scores = output1[..., 5:] # Class scores (from index 5 onwards, representing 80 classes)
-    objectness = torch.sigmoid(objectness)
-    class_scores = torch.sigmoid(class_scores)
-    confidence_mask = objectness > 0.5
-    objectness = objectness[confidence_mask]
-    class_scores = class_scores[confidence_mask]
-    scores, class_ids = class_scores.max(dim=-1) # Get max class score per cell
-    scores = scores * objectness # Adjust scores by objectness
-    boxes = output1[..., :4] # First 4 values are x1, y1, x2, y2
-    boxes = boxes[confidence_mask] # Filter boxes by confidence mask
-    fig, ax = plt.subplots(1, figsize=(8, 8))
-    ax.axis("off")
-    ax.imshow(torch.tensor(batch_explanations[0][indx]).cpu().numpy(), cmap="plasma") # Display image
-    top_score_idx = scores.argmax(dim=0) # Get the index of the max score
-    top_score = scores[top_score_idx].item()
-    top_class_id = class_ids[top_score_idx].item()
-    top_box = boxes[top_score_idx].cpu().numpy()
-    scale_factor = 16
-    x1, y1, x2, y2 = top_box
-    x1, y1, x2, y2 = x1 * scale_factor, y1 * scale_factor, x2 * scale_factor, y2 * scale_factor
-    rect = patches.Rectangle(
-        (x1, y1), x2 - x1, y2 - y1,
-        linewidth=2, edgecolor='r', facecolor='none')
+    # Getting predictions directly from YOLO
+    with torch.no_grad():
+        results = model(input_tensor)
+
+    # Post-processing to extract detections from the ultralytics Results object
+    boxes = results[0].boxes.xyxy.cpu().numpy()   # (N, 4) boxes as x1, y1, x2, y2
+    scores = results[0].boxes.conf.cpu().numpy()  # (N,) confidence scores
+    classes = results[0].boxes.cls.cpu().numpy()  # (N,) class indices
+
+    # Filter detections with confidence score > threshold (e.g., 0.5)
+    high_conf_boxes = boxes[scores > 0.5]
+    high_conf_classes = classes[scores > 0.5]
+
+    # Use the processed detections for visualization and further tasks
+    fig, ax = plt.subplots(1, figsize=(8, 8))
+    ax.axis("off")
+    ax.imshow(rgb_img_float)
+
+    for box, cls in zip(high_conf_boxes, high_conf_classes):
+        x1, y1, x2, y2 = box
+        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
+                                 linewidth=2, edgecolor='r', facecolor='none')
         ax.add_patch(rect)
+        ax.text(x1, y1, f"Class {int(cls)}", color='r', fontsize=12, verticalalignment='top')
 
-
-
-    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
-
-    fig.canvas.draw() # Draw the canvas to make sure the image is rendered
-    image_array = np.array(fig.canvas.renderer.buffer_rgba()) # Convert to numpy array
-    print("____________image_arrya", image_array.shape)
-    image_resized = cv2.resize(image_array, (640, 640))
-    rgba_channels = cv2.split(image_resized)
-    alpha_channel = rgba_channels[3]
-    rgb_channels = np.stack(rgba_channels[:3], axis=-1)
-    #overlay_img = (alpha_channel[..., None] * image) + ((1 - alpha_channel[..., None]) * rgb_channels)
-
-    #temp = image_array.reshape((rgb_img_float.shape[0],rgb_img_float.shape[1]) )
-    #visualization = show_factorization_on_image(rgb_img_float, image_array.resize((rgb_img_float.shape)) , image_weight=0.3)
-    visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
-    results.append(visualization)
-    plt.clf()
-    #return image_array
-
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    fig.canvas.draw()  # Draw the canvas so the figure is rendered
+    image_array = np.array(fig.canvas.renderer.buffer_rgba())  # Convert to numpy array
+    image_resized = cv2.resize(image_array, (640, 640))
+    rgba_channels = cv2.split(image_resized)
+    alpha_channel = rgba_channels[3]
+    rgb_channels = np.stack(rgba_channels[:3], axis=-1)
+
+    visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
+
-    return rgb_img_float, batch_explanations, results
+    return rgb_img_float, batch_explanations, visualization
 
 
 def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
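For context, a minimal smoke test for the updated dff_nmf, assuming it runs in the same module as the function above; the file name "sample.jpg" and the target_lyr/n_components values are illustrative assumptions, not part of the commit:

import cv2

# Load a test image and convert BGR -> RGB so dff_nmf's float normalization
# operates on RGB data, matching the mean it defines for RGB channels.
img = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))  # match YOLOv8's expected input size

# target_lyr=10 and n_components=3 are illustrative values.
rgb_img_float, batch_explanations, visualization = dff_nmf(img, target_lyr=10, n_components=3)

# Save the blended factorization overlay (RGB -> BGR for OpenCV's writer).
cv2.imwrite("dff_visualization.png", cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR))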