File size: 2,814 Bytes
a48b15f
 
ce6be70
f3cd231
e64ca4e
00b5438
f3cd231
0f7de99
a48b15f
 
 
 
e64ca4e
a48b15f
 
f3cd231
a48b15f
0f7de99
a48b15f
 
 
00b5438
a48b15f
00b5438
 
 
 
874e761
a48b15f
 
 
 
 
d2471f2
ce6be70
d2471f2
ce6be70
0f7de99
 
a48b15f
 
 
0f7de99
a48b15f
 
 
 
00b5438
 
 
 
 
 
d2471f2
00b5438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75b9622
 
 
 
 
a48b15f
0f7de99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import warnings

import numpy as np

from lmsim.metrics import Metrics, CAPA, EC

from src.dataloading import load_run_data_cached
from src.utils import softmax, one_hot

def load_data_and_compute_similarities(models: list[str], dataset: str, metric_name: str) -> np.array:
    """Load every model's run on ``dataset`` and return their pairwise similarity matrix.

    Args:
        models: Model identifiers, each passed to ``load_run_data_cached``.
        dataset: Dataset identifier passed to ``load_run_data_cached``.
        metric_name: Name of the similarity metric, forwarded unchanged to
            ``compute_pairwise_similarities``.

    Returns:
        The matrix produced by ``compute_pairwise_similarities``; row/column ``k``
        corresponds to ``models[k]``.
    """
    runs = [load_run_data_cached(name, dataset) for name in models]
    # Each run is an (outputs, ground truth) pair; split them into parallel lists.
    outputs_per_model = [outputs for outputs, _ in runs]
    labels_per_model = [labels for _, labels in runs]
    return compute_pairwise_similarities(metric_name, outputs_per_model, labels_per_model)


def compute_similarity(metric: Metrics, outputs_a: list[np.ndarray], outputs_b: list[np.ndarray], gt: list[int]) -> float:
    """Score the similarity of two models' responses with ``metric``.

    Args:
        metric: Metric object; its ``compute_k(outputs_a, outputs_b, gt)`` method
            produces the score.
        outputs_a: Per-question outputs of the first model.
        outputs_b: Per-question outputs of the second model, aligned with ``outputs_a``.
        gt: Ground-truth labels, one per question.

    Returns:
        Whatever ``metric.compute_k`` returns (annotated as a float).

    Raises:
        AssertionError: If the three sequences do not have the same length. This is
            raised explicitly rather than via an ``assert`` statement so the check
            is not stripped when Python runs with ``-O``.
    """
    if not (len(outputs_a) == len(outputs_b) == len(gt)):
        raise AssertionError(
            f"Models must have the same number of responses: {len(outputs_a)} != {len(outputs_b)} != {len(gt)}"
        )
    return metric.compute_k(outputs_a, outputs_b, gt)


def _align_on_shared_labels(outputs_a, outputs_b, gt_a, gt_b):
    """Restrict two models' outputs to the questions whose ground-truth labels agree.

    Never mutates its arguments: when some labels disagree, new lists are built;
    when all labels agree, the inputs are returned as-is.

    Returns:
        ``(outputs_a, outputs_b, gt)`` restricted to the agreeing positions.

    Raises:
        ValueError: If the four sequences do not all have the same length, since
            positions could not then be matched up reliably.
    """
    if not (len(outputs_a) == len(outputs_b) == len(gt_a) == len(gt_b)):
        raise ValueError(
            f"Cannot align responses of different lengths: outputs {len(outputs_a)}/{len(outputs_b)}, "
            f"ground truth {len(gt_a)}/{len(gt_b)}"
        )
    keep = [k for k, (a, b) in enumerate(zip(gt_a, gt_b)) if a == b]
    if len(keep) == len(gt_a):
        return outputs_a, outputs_b, gt_a
    return (
        [outputs_a[k] for k in keep],
        [outputs_b[k] for k in keep],
        [gt_a[k] for k in keep],
    )


def compute_pairwise_similarities(metric_name: str, probs: list[list[np.ndarray]], gts: list[list[int]]) -> np.ndarray:
    """Compute a symmetric matrix of pairwise model similarities.

    Args:
        metric_name: One of ``"CAPA"``, ``"CAPA (det.)"`` or ``"Error Consistency"``.
        probs: One list of per-question outputs per model. For ``"CAPA"`` each
            output is passed through ``softmax`` (so presumably raw logits —
            confirm against the data loader); for ``"CAPA (det.)"`` through
            ``one_hot``; for ``"Error Consistency"`` it is used unchanged.
        gts: One list of ground-truth labels per model, aligned with ``probs``.

    Returns:
        A ``(len(probs), len(probs))`` float array whose ``[i, j]`` entry is the
        similarity of models ``i`` and ``j`` (diagonal included). A pair whose
        computation fails is set to NaN and reported with a ``RuntimeWarning``.
        Questions on which two models' ground-truth labels differ are excluded
        for that pair only; the input lists are never modified.

    Raises:
        ValueError: If ``metric_name`` is not a supported metric.
    """
    # Select the metric and pre-format each model's outputs once (not once per pair).
    if metric_name == "CAPA":
        metric = CAPA()
        probs = [[softmax(logits) for logits in model_probs] for model_probs in probs]
    elif metric_name == "CAPA (det.)":
        metric = CAPA(prob=False)
        # Convert probabilities to one-hot
        probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
    elif metric_name == "Error Consistency":
        metric = EC()
    else:
        raise ValueError(f"Invalid metric: {metric_name}")

    n_models = len(probs)
    similarities = np.zeros((n_models, n_models))
    for i in range(n_models):
        for j in range(i, n_models):
            outputs_a, outputs_b = probs[i], probs[j]
            gt_a, gt_b = gts[i], gts[j]
            try:
                aligned_a, aligned_b, gt = _align_on_shared_labels(outputs_a, outputs_b, gt_a, gt_b)
                similarities[i, j] = compute_similarity(metric, aligned_a, aligned_b, gt)
            except Exception as err:
                # Best effort: one failing pair must not abort the whole matrix,
                # but the failure should be visible rather than a silent NaN.
                warnings.warn(
                    f"Could not compute {metric_name} similarity for models {i} and {j}: {err!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                similarities[i, j] = np.nan

            similarities[j, i] = similarities[i, j]
    return similarities