BhumikaMak commited on
Commit
f9e81bd
·
1 Parent(s): 4345023

Fix: version dependency

Browse files
Files changed (1) hide show
  1. app.py +41 -54
app.py CHANGED
@@ -10,49 +10,33 @@ from pytorch_grad_cam import EigenCAM
10
  from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
11
  from PIL import Image
12
  import gradio as gr
13
- from ultralytics import YOLO
14
 
15
  # Global Color Palette
16
  COLORS = np.random.uniform(0, 255, size=(80, 3))
17
 
18
  # Function to parse YOLO detections
19
- def parse_detections(results, yolo_version):
 
20
  boxes, colors, names = [], [], []
21
- if yolo_version == "yolov5":
22
- detections = results.pandas().xyxy[0].to_dict()
23
- for i in range(len(detections["xmin"])):
24
- confidence = detections["confidence"][i]
25
- if confidence < 0.2:
26
- continue
27
- xmin, ymin = int(detections["xmin"][i]), int(detections["ymin"][i])
28
- xmax, ymax = int(detections["xmax"][i]), int(detections["ymax"][i])
29
- name, category = detections["name"][i], int(detections["class"][i])
30
- boxes.append((xmin, ymin, xmax, ymax))
31
- colors.append(COLORS[category])
32
- names.append(name)
33
- elif yolo_version == "yolov8":
34
- # For YOLOv8
35
- for result in results:
36
- print('resukt', result)
37
- for box in result.boxes:
38
- print('box', box)
39
- xmin, ymin, xmax, ymax = box[0].xyxy.tolist()
40
- confidence = box.conf.item()
41
- class_id = int(box.cls.item())
42
- if confidence > 0.2:
43
- boxes.append((xmin, ymin, xmax, ymax))
44
- colors.append(COLORS[class_id % len(COLORS)])
45
- names.append(result.names[class_id])
46
-
47
  return boxes, colors, names
48
 
49
  # Draw bounding boxes and labels
50
  def draw_detections(boxes, colors, names, img):
51
  for box, color, name in zip(boxes, colors, names):
52
- xmin, ymin, xmax, ymax = map(int, box)
53
- cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color.tolist(), 2)
54
  cv2.putText(img, name, (xmin, ymin - 5),
55
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, color.tolist(), 2,
56
  lineType=cv2.LINE_AA)
57
  return img
58
 
@@ -60,22 +44,22 @@ def draw_detections(boxes, colors, names, img):
60
  def load_yolo_model(version="yolov5"):
61
  if version == "yolov5":
62
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
63
- elif version == "yolov8":
64
- model = YOLO("yolov8n.pt") # Ensure you have this file available
65
  else:
66
  raise ValueError(f"Unsupported YOLO version: {version}")
67
 
68
  model.eval() # Set to evaluation mode
 
69
  return model
70
 
71
  def process_image(image, yolo_versions=["yolov5"]):
72
  image = np.array(image)
73
- image_resized = cv2.resize(image, (640, 640))
74
- rgb_img = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB) # Convert to RGB for display
 
75
 
76
  # Image transformation
77
  transform = transforms.ToTensor()
78
- tensor = transform(image_resized).unsqueeze(0)
79
 
80
  # Initialize list to store result images with captions
81
  result_images = []
@@ -84,24 +68,27 @@ def process_image(image, yolo_versions=["yolov5"]):
84
  for yolo_version in yolo_versions:
85
  # Load the model based on YOLO version
86
  model = load_yolo_model(yolo_version)
87
-
 
88
  # Run YOLO detection
89
- results = model([rgb_img]) # Ensure this is a list containing one image
90
-
91
- # Parse detections using updated function
92
- boxes, colors, names = parse_detections(results[0], yolo_version)
93
-
94
- detections_img = draw_detections(boxes, colors.copy(), names.copy(), rgb_img.copy())
95
 
96
  # Grad-CAM visualization
97
- target_layers = [model.model.model[-1]] # Use last layer as target layer for Grad-CAM
98
- cam = EigenCAM(model=model.model.model[-1], target_layers=target_layers)
99
- grayscale_cam = cam(tensor)[0]
100
-
101
- cam_image = show_cam_on_image(image_resized / 255.0, grayscale_cam.numpy(), use_rgb=True)
 
 
 
 
 
102
 
103
  # Concatenate images and prepare the caption
104
- final_image = np.hstack((rgb_img.copy(), cam_image))
105
  caption = f"Results using {yolo_version}"
106
  result_images.append((Image.fromarray(final_image), caption))
107
 
@@ -112,15 +99,15 @@ interface = gr.Interface(
112
  inputs=[
113
  gr.Image(type="pil", label="Upload an Image"),
114
  gr.CheckboxGroup(
115
- choices=["yolov5", "yolov8"],
116
  value=["yolov5"], # Set the default value (YOLOv5 checked by default)
117
  label="Select Model(s)",
118
  )
119
  ],
120
- outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2),
121
- title="Visualizing Key Image Features with Explainable AI Tool",
122
- description="Upload an image to visualize object detection of your models."
123
  )
124
 
125
  if __name__ == "__main__":
126
- interface.launch()
 
10
  from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
11
  from PIL import Image
12
  import gradio as gr
 
13
 
14
  # Global Color Palette
15
  COLORS = np.random.uniform(0, 255, size=(80, 3))
16
 
17
  # Function to parse YOLO detections
18
+ def parse_detections(results):
19
+ detections = results.pandas().xyxy[0].to_dict()
20
  boxes, colors, names = [], [], []
21
+ for i in range(len(detections["xmin"])):
22
+ confidence = detections["confidence"][i]
23
+ if confidence < 0.2:
24
+ continue
25
+ xmin, ymin = int(detections["xmin"][i]), int(detections["ymin"][i])
26
+ xmax, ymax = int(detections["xmax"][i]), int(detections["ymax"][i])
27
+ name, category = detections["name"][i], int(detections["class"][i])
28
+ boxes.append((xmin, ymin, xmax, ymax))
29
+ colors.append(COLORS[category])
30
+ names.append(name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return boxes, colors, names
32
 
33
  # Draw bounding boxes and labels
34
  def draw_detections(boxes, colors, names, img):
35
  for box, color, name in zip(boxes, colors, names):
36
+ xmin, ymin, xmax, ymax = box
37
+ cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)
38
  cv2.putText(img, name, (xmin, ymin - 5),
39
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
40
  lineType=cv2.LINE_AA)
41
  return img
42
 
 
44
  def load_yolo_model(version="yolov5"):
45
  if version == "yolov5":
46
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
 
 
47
  else:
48
  raise ValueError(f"Unsupported YOLO version: {version}")
49
 
50
  model.eval() # Set to evaluation mode
51
+ model.cpu()
52
  return model
53
 
54
  def process_image(image, yolo_versions=["yolov5"]):
55
  image = np.array(image)
56
+ image = cv2.resize(image, (640, 640))
57
+ rgb_img = image.copy()
58
+ img_float = np.float32(image) / 255
59
 
60
  # Image transformation
61
  transform = transforms.ToTensor()
62
+ tensor = transform(img_float).unsqueeze(0)
63
 
64
  # Initialize list to store result images with captions
65
  result_images = []
 
68
  for yolo_version in yolo_versions:
69
  # Load the model based on YOLO version
70
  model = load_yolo_model(yolo_version)
71
+ target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
72
+
73
  # Run YOLO detection
74
+ results = model([rgb_img])
75
+ boxes, colors, names = parse_detections(results)
76
+ detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
 
 
 
77
 
78
  # Grad-CAM visualization
79
+ cam = EigenCAM(model, target_layers)
80
+ grayscale_cam = cam(tensor)[0, :, :]
81
+ cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
82
+
83
+ # Renormalize Grad-CAM inside bounding boxes
84
+ renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
85
+ for x1, y1, x2, y2 in boxes:
86
+ renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
87
+ renormalized_cam = scale_cam_image(renormalized_cam)
88
+ renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
89
 
90
  # Concatenate images and prepare the caption
91
+ final_image = np.hstack((rgb_img, cam_image, renormalized_cam_image))
92
  caption = f"Results using {yolo_version}"
93
  result_images.append((Image.fromarray(final_image), caption))
94
 
 
99
  inputs=[
100
  gr.Image(type="pil", label="Upload an Image"),
101
  gr.CheckboxGroup(
102
+ choices=["yolov5"],
103
  value=["yolov5"], # Set the default value (YOLOv5 checked by default)
104
  label="Select Model(s)",
105
  )
106
  ],
107
+ outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
108
+ title="Visualising the key image features that drive decisions with our explainable AI tool.",
109
+ description="XAI: Upload an image to visualize object detection of your models.."
110
  )
111
 
112
  if __name__ == "__main__":
113
+ interface.launch()