BhumikaMak commited on
Commit
e33c04c
·
verified ·
1 Parent(s): 9d0349e

update: model

Browse files
Files changed (1) hide show
  1. yolov8.py +4 -4
yolov8.py CHANGED
@@ -171,9 +171,9 @@ def dff_nmf(image, target_lyr, n_components):
171
  rgb_img_float = np.float32(img) / 255.0
172
  input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
173
 
174
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device)
175
  dff= DeepFeatureFactorization(model=model,
176
- target_layer=model.model.model.model[int(target_lyr)],
177
  computation_on_concepts=None)
178
 
179
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
@@ -183,7 +183,7 @@ def dff_nmf(image, target_lyr, n_components):
183
  # "https://github.com/ultralytics/yolov5/raw/master/data/coco128.yaml" # URL to the YOLOv5 categories file
184
  #yaml_data = requests.get(yolov5_categories_url).text
185
  # labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
186
- num_classes = model.model.model.model[-1].nc
187
  results = []
188
  for indx in range(explanations[0].shape[0]):
189
  upsampled_input = explanations[0][indx]
@@ -191,7 +191,7 @@ def dff_nmf(image, target_lyr, n_components):
191
  device = next(model.parameters()).device
192
  input_tensor = upsampled_input.unsqueeze(0)
193
  input_tensor = input_tensor.unsqueeze(1).repeat(1, 128, 1, 1)
194
- detection_lyr = model.model.model.model[-1]
195
  output1 = detection_lyr.m[0](input_tensor.to(device))
196
  objectness = output1[..., 4] # Objectness score (index 4)
197
  class_scores = output1[..., 5:] # Class scores (from index 5 onwards, representing 80 classes)
 
171
  rgb_img_float = np.float32(img) / 255.0
172
  input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
173
 
174
+ model = YOLO('yolov8s.pt')
175
  dff= DeepFeatureFactorization(model=model,
176
+ target_layer=model.model.model[int(target_lyr)],
177
  computation_on_concepts=None)
178
 
179
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
 
183
  # "https://github.com/ultralytics/yolov5/raw/master/data/coco128.yaml" # URL to the YOLOv5 categories file
184
  #yaml_data = requests.get(yolov5_categories_url).text
185
  # labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
186
+ num_classes = model.model.model[-1].nc
187
  results = []
188
  for indx in range(explanations[0].shape[0]):
189
  upsampled_input = explanations[0][indx]
 
191
  device = next(model.parameters()).device
192
  input_tensor = upsampled_input.unsqueeze(0)
193
  input_tensor = input_tensor.unsqueeze(1).repeat(1, 128, 1, 1)
194
+ detection_lyr = model.model.model[-1]
195
  output1 = detection_lyr.m[0](input_tensor.to(device))
196
  objectness = output1[..., 4] # Objectness score (index 4)
197
  class_scores = output1[..., 5:] # Class scores (from index 5 onwards, representing 80 classes)