# lm-similarity/src/similarity.py
import numpy as np
from lmsim.metrics import Metrics, CAPA, EC
from src.dataloading import load_run_data_cached
from src.utils import softmax, one_hot


def load_data_and_compute_similarities(models: list[str], dataset: str, metric_name: str) -> np.ndarray:
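    """Load cached model outputs for `dataset` and return the pairwise similarity matrix.

    Args:
        models: model identifiers understood by `load_run_data_cached`.
        dataset: name of the evaluation dataset whose runs are loaded.
        metric_name: one of "CAPA", "CAPA (det.)", or "Error Consistency".
    """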
    # Load per-model outputs and ground-truth labels from the cached runs
    probs = []
    gts = []
    for model in models:
        model_probs, model_gt = load_run_data_cached(model, dataset)
        probs.append(model_probs)
        gts.append(model_gt)

    # Compute pairwise similarities
    similarities = compute_pairwise_similarities(metric_name, probs, gts)
    return similarities


def compute_similarity(metric: Metrics, outputs_a: list[np.ndarray], outputs_b: list[np.ndarray], gt: list[int]) -> float:
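    """Compute the similarity of two models' per-question outputs under `metric`."""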
    # Check that the models have the same number of responses
    assert len(outputs_a) == len(outputs_b) == len(gt), (
        f"Models must have the same number of responses: "
        f"{len(outputs_a)} != {len(outputs_b)} != {len(gt)}"
    )

    # Compute the similarity value
    similarity = metric.compute_k(outputs_a, outputs_b, gt)
    return similarity


def compute_pairwise_similarities(metric_name: str, probs: list[list[np.ndarray]], gts: list[list[int]]) -> np.ndarray:
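    """Build a symmetric matrix of pairwise similarities for all loaded models.

    `probs[i]` holds one output array (logits) per question for model i, and
    `gts[i]` the corresponding ground-truth answer indices. Pairs that fail to
    compute are set to NaN.
    """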
    # Select the chosen metric
    if metric_name == "CAPA":
        metric = CAPA()
    elif metric_name == "CAPA (det.)":
        metric = CAPA(prob=False)
        # Convert logits to one-hot predictions for the deterministic variant
        probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
    elif metric_name == "Error Consistency":
        # Error Consistency also operates on one-hot predictions
        probs = [[one_hot(p) for p in model_probs] for model_probs in probs]
        metric = EC()
    else:
        raise ValueError(f"Invalid metric: {metric_name}")
    similarities = np.zeros((len(probs), len(probs)))
    for i in range(len(probs)):
        for j in range(i, len(probs)):
            # Copy the per-model lists so the per-pair filtering below does not
            # mutate the shared data used by later pairs
            outputs_a = list(probs[i])
            outputs_b = list(probs[j])
            gt_a = list(gts[i])
            gt_b = list(gts[j])
            # For probabilistic CAPA, convert logits to probabilities via softmax
            if metric_name == "CAPA":
                outputs_a = [softmax(logits) for logits in outputs_a]
                outputs_b = [softmax(logits) for logits in outputs_b]

            # Keep only questions on which both runs agree on the ground-truth index
            if gt_a != gt_b:
                indices_to_remove = [idx for idx, (a, b) in enumerate(zip(gt_a, gt_b)) if a != b]
                # Delete from the back so the remaining indices stay valid
                for idx in sorted(indices_to_remove, reverse=True):
                    del outputs_a[idx]
                    del outputs_b[idx]
                    del gt_a[idx]
                    del gt_b[idx]
            try:
                similarities[i, j] = compute_similarity(metric, outputs_a, outputs_b, gt_a)
            except Exception as e:
                print(f"Failed to compute similarity between models {i} and {j}: {e}")
                similarities[i, j] = np.nan
            # Similarity is symmetric, so mirror the value across the diagonal
            similarities[j, i] = similarities[i, j]

    return similarities
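

# Minimal usage sketch (illustrative only): the model identifiers and dataset name
# below are hypothetical placeholders; real values must correspond to runs that
# `load_run_data_cached` can find.
if __name__ == "__main__":
    example_models = ["model-a", "model-b"]  # hypothetical model identifiers
    sims = load_data_and_compute_similarities(example_models, "example-dataset", "CAPA")
    print(sims)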