BhumikaMak commited on
Commit
136ee37
·
1 Parent(s): 86915f5

Fix: Revert -12.

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -10,7 +10,6 @@ from pytorch_grad_cam import EigenCAM
10
  from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
11
  from PIL import Image
12
  import gradio as gr
13
- from ultralytics import YOLO
14
 
15
  # Global Color Palette
16
  COLORS = np.random.uniform(0, 255, size=(80, 3))
@@ -46,7 +45,7 @@ def load_yolo_model(version="yolov5"):
46
  if version == "yolov5":
47
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
48
  elif version == "yolov8":
49
- model = YOLO("yolov8n.pt") # YOLOv8 is part of the ultralytics library
50
  elif version == "yolov10":
51
  model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
52
  else:
@@ -56,6 +55,7 @@ def load_yolo_model(version="yolov5"):
56
  model.cpu()
57
  return model
58
 
 
59
  # Main function for Grad-CAM visualization
60
  def process_image(image, yolo_versions=["yolov5"]):
61
  image = np.array(image)
@@ -74,12 +74,7 @@ def process_image(image, yolo_versions=["yolov5"]):
74
  for yolo_version in yolo_versions:
75
  # Load the model based on YOLO version
76
  model = load_yolo_model(yolo_version)
77
-
78
- # YOLOv8: Extract last layer by model.model[-1] would not work; use the following:
79
- if isinstance(model.model, torch.nn.Sequential):
80
- target_layers = [model.model[-1]] # This assumes model layers are in a Sequential container
81
- else:
82
- target_layers = [model.model[-2]] # Use an appropriate layer
83
 
84
  # Run YOLO detection
85
  results = model([rgb_img])
@@ -105,21 +100,20 @@ def process_image(image, yolo_versions=["yolov5"]):
105
 
106
  return result_images
107
 
108
- # Gradio interface
109
  interface = gr.Interface(
110
  fn=process_image,
111
  inputs=[
112
  gr.Image(type="pil", label="Upload an Image"),
113
  gr.CheckboxGroup(
114
- choices=["yolov5", "yolov8", "yolov10"],
115
  value=["yolov5"], # Set the default value (YOLOv5 checked by default)
116
  label="Select Model(s)",
117
  )
118
  ],
119
- outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
120
- title="Visualizing the key image features that drive decisions with our explainable AI tool.",
121
  description="XAI: Upload an image to visualize object detection of your models.."
122
  )
123
 
124
  if __name__ == "__main__":
125
- interface.launch()
 
10
  from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
11
  from PIL import Image
12
  import gradio as gr
 
13
 
14
  # Global Color Palette
15
  COLORS = np.random.uniform(0, 255, size=(80, 3))
 
45
  if version == "yolov5":
46
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
47
  elif version == "yolov8":
48
+ model = torch.hub.load('ultralytics/yolov5:v7.0', 'yolov5', pretrained=True) # YOLOv8 is part of the yolov5 repo starting from v7.0
49
  elif version == "yolov10":
50
  model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
51
  else:
 
55
  model.cpu()
56
  return model
57
 
58
+ # Main function for Grad-CAM visualization
59
  # Main function for Grad-CAM visualization
60
  def process_image(image, yolo_versions=["yolov5"]):
61
  image = np.array(image)
 
74
  for yolo_version in yolo_versions:
75
  # Load the model based on YOLO version
76
  model = load_yolo_model(yolo_version)
77
+ target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
 
 
 
 
 
78
 
79
  # Run YOLO detection
80
  results = model([rgb_img])
 
100
 
101
  return result_images
102
 
 
103
  interface = gr.Interface(
104
  fn=process_image,
105
  inputs=[
106
  gr.Image(type="pil", label="Upload an Image"),
107
  gr.CheckboxGroup(
108
+ choices=["yolov5", "yolov8", "yolov10"],
109
  value=["yolov5"], # Set the default value (YOLOv5 checked by default)
110
  label="Select Model(s)",
111
  )
112
  ],
113
+ outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
114
+ title="Visualising the key image features that drive decisions with our explainable AI tool.",
115
  description="XAI: Upload an image to visualize object detection of your models.."
116
  )
117
 
118
  if __name__ == "__main__":
119
+ interface.launch()