BhumikaMak commited on
Commit
c8e3f60
·
1 Parent(s): ef20834

Fix: Revert -5.

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -45,9 +45,13 @@ def load_yolo_model(version="yolov5"):
45
  if version == "yolov3":
46
  model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)
47
  elif version == "yolov5":
48
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # Load yolov5 small model
 
 
49
  elif version == "yolov8":
50
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # YOLOv8 can be accessed via yolov5 repo
 
 
51
  else:
52
  raise ValueError(f"Unsupported YOLO version: {version}")
53
 
@@ -55,6 +59,7 @@ def load_yolo_model(version="yolov5"):
55
  model.cpu()
56
  return model
57
 
 
58
  # Main function for Grad-CAM visualization
59
  def process_image(image, yolo_versions=["yolov5"]):
60
  image = np.array(image)
@@ -73,12 +78,10 @@ def process_image(image, yolo_versions=["yolov5"]):
73
  for yolo_version in yolo_versions:
74
  # Load the model based on YOLO version
75
  model = load_yolo_model(yolo_version)
76
-
77
- # In YOLOv5, the last convolutional layer in the backbone is typically the one to use
78
- target_layers = [model.model.model[-2]] # Assuming last layer of the model is a Conv layer
79
-
80
  # Run YOLO detection
81
- results = model(rgb_img)
82
  boxes, colors, names = parse_detections(results)
83
  detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
84
 
@@ -101,21 +104,20 @@ def process_image(image, yolo_versions=["yolov5"]):
101
 
102
  return result_images
103
 
104
- # Define the Gradio interface
105
  interface = gr.Interface(
106
  fn=process_image,
107
  inputs=[
108
  gr.Image(type="pil", label="Upload an Image"),
109
  gr.CheckboxGroup(
110
- choices=["yolov3", "yolov5", "yolov8"],
111
  value=["yolov5"], # Set the default value (YOLOv5 checked by default)
112
  label="Select Model(s)",
113
  )
114
  ],
115
  outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
116
  title="Visualising the key image features that drive decisions with our explainable AI tool.",
117
- description="XAI: Upload an image to visualize object detection of your models."
118
  )
119
 
120
  if __name__ == "__main__":
121
- interface.launch()
 
45
  if version == "yolov3":
46
  model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)
47
  elif version == "yolov5":
48
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
49
+ elif version == "yolov7":
50
+ model = torch.hub.load('WongKinYiu/yolov7', 'yolov7', pretrained=True)
51
  elif version == "yolov8":
52
+ model = torch.hub.load('ultralytics/yolov5:v7.0', 'yolov5', pretrained=True) # YOLOv8 is part of the yolov5 repo starting from v7.0
53
+ elif version == "yolov10":
54
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
55
  else:
56
  raise ValueError(f"Unsupported YOLO version: {version}")
57
 
 
59
  model.cpu()
60
  return model
61
 
62
+ # Main function for Grad-CAM visualization
63
  # Main function for Grad-CAM visualization
64
  def process_image(image, yolo_versions=["yolov5"]):
65
  image = np.array(image)
 
78
  for yolo_version in yolo_versions:
79
  # Load the model based on YOLO version
80
  model = load_yolo_model(yolo_version)
81
+ target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
82
+
 
 
83
  # Run YOLO detection
84
+ results = model([rgb_img])
85
  boxes, colors, names = parse_detections(results)
86
  detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
87
 
 
104
 
105
  return result_images
106
 
 
107
  interface = gr.Interface(
108
  fn=process_image,
109
  inputs=[
110
  gr.Image(type="pil", label="Upload an Image"),
111
  gr.CheckboxGroup(
112
+ choices=["yolov3", "yolov5", "yolov7", "yolov8", "yolov10"],
113
  value=["yolov5"], # Set the default value (YOLOv5 checked by default)
114
  label="Select Model(s)",
115
  )
116
  ],
117
  outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
118
  title="Visualising the key image features that drive decisions with our explainable AI tool.",
119
+ description="XAI: Upload an image to visualize object detection of your models.."
120
  )
121
 
122
  if __name__ == "__main__":
123
+ interface.launch()