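"""Compute pairwise similarity matrices between language-model runs.

Loads per-model outputs with load_run_data and scores every pair of models
with a similarity metric from lmsim (Kappa_p or Error Consistency).
"""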
import numpy as np

from lmsim.metrics import Metrics, Kappa_p, EC

from src.dataloading import load_run_data
from src.utils import softmax, one_hot

def load_data_and_compute_similarities(models: list[str], dataset: str, metric_name: str) -> np.ndarray:
    """Load per-model outputs for a dataset and return their pairwise similarity matrix."""
    # Load data
    probs = []
    gts = []
    for model in models:
        model_probs, model_gt = load_run_data(model, dataset)
        probs.append(model_probs)
        gts.append(model_gt)
    
    # Compute pairwise similarities
    similarities = compute_pairwise_similarities(metric_name, probs, gts)
    return similarities


def compute_similarity(metric: Metrics, outputs_a: list[np.ndarray], outputs_b: list[np.ndarray], gt: list[int]) -> float:
    """Compute the similarity between two models' outputs under the given metric."""
    # Check that the models have the same number of responses
    assert len(outputs_a) == len(outputs_b) == len(gt), \
        f"Models must have the same number of responses: {len(outputs_a)} != {len(outputs_b)} != {len(gt)}"
    
    # Compute similarity values
    similarity = metric.compute_k(outputs_a, outputs_b, gt)
    
    return similarity


def compute_pairwise_similarities(metric_name: str, probs: list[list[np.ndarray]], gts: list[list[int]]) -> np.ndarray:
    """Build the symmetric similarity matrix across all pairs of models."""
    # Select chosen metric
    if metric_name == "Kappa_p (prob.)":
        metric = Kappa_p()
    elif metric_name == "Kappa_p (det.)":
        metric = Kappa_p(prob=False)
        # Convert probabilities to one-hot
        probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
    elif metric_name == "Error Consistency":
        metric = EC()
    else:
        raise ValueError(f"Invalid metric: {metric_name}") 

    similarities = np.zeros((len(probs), len(probs)))
    for i in range(len(probs)):
        for j in range(i, len(probs)):
            # Copy so that the filtering below cannot mutate the shared lists
            outputs_a = list(probs[i])
            outputs_b = list(probs[j])
            gt_a = list(gts[i])
            gt_b = list(gts[j])

            # Format softmax outputs
            if metric_name == "Kappa_p (prob.)":
                outputs_a = [softmax(logits) for logits in outputs_a]
                outputs_b = [softmax(logits) for logits in outputs_b]

            # Keep only the items on which both runs agree on the ground-truth label
            if gt_a != gt_b:
                indices_to_remove = [
                    idx for idx, (a, b) in enumerate(zip(gt_a, gt_b)) if a != b
                ]
                for idx in sorted(indices_to_remove, reverse=True):
                    del outputs_a[idx]
                    del outputs_b[idx]
                    del gt_a[idx]
                    del gt_b[idx]

            try:
                similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)
            except Exception:
                # Record failures as NaN instead of aborting the whole matrix
                similarities[i, j] = np.nan

            # Mirror into the lower triangle; the matrix is symmetric
            similarities[j, i] = similarities[i, j]
    return similarities
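

if __name__ == "__main__":
    # Minimal usage sketch. The model names and dataset identifier below are
    # hypothetical placeholders; substitute whatever load_run_data expects.
    models = ["model_a", "model_b", "model_c"]
    sims = load_data_and_compute_similarities(models, "example_dataset", "Kappa_p (prob.)")
    print(sims)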