File size: 887 Bytes
00b5438
 
 
 
 
 
 
 
 
 
0a42e99
 
b16e2d1
0a42e99
 
 
b16e2d1
 
0a42e99
 
b16e2d1
 
 
0a42e99
 
c608f7f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Return the softmax of *logits*, normalised along axis 0.

    For a 1-D input this is the ordinary softmax over all entries. For an
    N-D input every slice along axis 0 (e.g. each column of a 2-D array) is
    normalised independently, so each such slice sums to 1.

    Args:
        logits: Array of unnormalised scores.

    Returns:
        Array with the same shape as *logits* holding probabilities.
    """
    # Softmax is shift-invariant, so subtracting a constant per normalised
    # slice does not change the result. Using each slice's own max (rather
    # than the global max) guarantees the largest exponent in every slice is
    # exp(0) == 1, so no slice can underflow to all zeros and produce 0/0.
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / exp_logits.sum(axis=0, keepdims=True)

def one_hot(probs: np.ndarray) -> np.ndarray:
    """Return an array marking the position of the largest entry of *probs*.

    The result has the same shape and dtype as *probs*. It is 1 at the
    position of the maximum value (the first occurrence, on ties) and 0
    everywhere else. The maximum is taken over the whole array, not per
    row or column.
    """
    encoded = np.zeros_like(probs)
    # np.argmax with no axis returns an index into the *flattened* array;
    # convert it back to an N-D index so multi-dimensional inputs mark a
    # single element instead of a whole row (or raising IndexError).
    encoded[np.unravel_index(np.argmax(probs), encoded.shape)] = 1
    return encoded

def opt_to_index(s):
    """Convert a multiple-choice option label to a zero-based index.

    Accepts a bare letter (``"B"`` / ``"b"``) or a letter wrapped in
    parentheses (``"(B)"`` / ``"(b)"``). Both forms map ``A``/``a`` -> 0,
    ``B``/``b`` -> 1, and so on.

    Raises:
        ValueError: if *s* is in neither form, or its letter is not a single
            ASCII letter.
    """
    # Unwrap "(X)" so both accepted forms go through the same validation.
    letter = s[1:-1] if s.startswith("(") and s.endswith(")") else s
    # Require exactly one ASCII letter so malformed labels such as "(AB)",
    # "()" or "(1)" are rejected instead of mapping to a wrong or negative
    # index, and lower-case letters are handled the same with or without
    # parentheses. The ASCII check also avoids characters like "ß" whose
    # upper-case form is more than one character long.
    if len(letter) == 1 and letter.isascii() and letter.isalpha():
        return ord(letter.upper()) - ord("A")
    raise ValueError("Invalid format")

def is_single_letter(s):
    """Return True when *s* consists of exactly one alphabetic character."""
    if len(s) != 1:
        return False
    return s.isalpha()
    
def get_test_target(doc):
    """Return the gold answer stored in *doc* along with the key it came from.

    Looks for ``"target"`` first, then ``"answer"``. Returns ``(value, key)``
    for the first key present, or ``("", "")`` when neither key exists.
    """
    for field in ("target", "answer"):
        if field in doc:
            return doc[field], field
    return "", ""