lm-similarity / src /utils.py
Joschka Strueber
[Fix] convert logits to softmax for kappa_p
00b5438
raw
history blame
302 Bytes
import numpy as np
def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Convert logits into probabilities with a numerically stable softmax.

    Args:
        logits: Array (or array-like) of unnormalised scores.
        axis: Axis along which the probabilities are normalised. Defaults
            to 0, so a 1-D vector is normalised as a whole and a 2-D array
            is normalised column by column.

    Returns:
        An array with the same shape as ``logits`` whose entries sum to 1
        along ``axis``.
    """
    logits = np.asarray(logits)
    # Shift by the maximum of each normalised slice rather than by the global
    # maximum: a slice lying far below the global maximum would otherwise
    # underflow to all zeros and yield 0 / 0 = NaN.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / exp_logits.sum(axis=axis, keepdims=True)
def one_hot(probs: np.ndarray) -> np.ndarray:
    """Return an array of zeros with a single 1 at the position of the maximum.

    Args:
        probs: Array (or array-like) of scores or probabilities.

    Returns:
        An array with the same shape and dtype as ``probs`` that is 0
        everywhere except at the position of the largest entry (the first
        one in flattened order when there are ties), which is set to 1.
    """
    encoded = np.zeros_like(probs)
    # np.argmax without an axis returns an index into the flattened array, so
    # assign through .flat; indexing the first axis is only right for 1-D.
    encoded.flat[np.argmax(probs)] = 1
    return encoded