Joschka Strueber commited on
Commit
65ef274
·
1 Parent(s): ea91c80

[Fix] sim check for gts values

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. src/similarity.py +1 -1
app.py CHANGED
@@ -43,7 +43,7 @@ def create_heatmap(selected_models, selected_dataset, selected_metric):
43
  )
44
 
45
  # Customize plot
46
- plt.title(f"{selected_metric} Similarities for {selected_dataset}", fontsize=16)
47
  plt.xlabel("Models", fontsize=14)
48
  plt.ylabel("Models", fontsize=14)
49
  plt.xticks(rotation=45, ha='right')
 
43
  )
44
 
45
  # Customize plot
46
+ plt.title(f"{selected_metric} for {selected_dataset}", fontsize=16)
47
  plt.xlabel("Models", fontsize=14)
48
  plt.ylabel("Models", fontsize=14)
49
  plt.xticks(rotation=45, ha='right')
src/similarity.py CHANGED
@@ -27,7 +27,7 @@ def compute_similarity(metric: Metrics, probs_a: list[np.array], gt_a: list[int]
27
  output_b = []
28
  gt = []
29
  for i in range(len(probs_a)):
30
- if gt_a == gt_b:
31
  output_a.append(probs_a[i])
32
  output_b.append(probs_b[i])
33
  gt.append(gt_a[i])
 
27
  output_b = []
28
  gt = []
29
  for i in range(len(probs_a)):
30
+ if gt_a[i] == gt_b[i]:
31
  output_a.append(probs_a[i])
32
  output_b.append(probs_b[i])
33
  gt.append(gt_a[i])