File size: 302 Bytes
00b5438
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import numpy as np


def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Return the softmax of *logits* along *axis*.

    Args:
        logits: Array of unnormalised scores (anything ``np.asarray`` accepts).
        axis: Axis along which the outputs sum to 1. Defaults to 0, matching
            the historical behaviour of this function.

    Returns:
        A float array with the same shape as *logits*; each slice along
        *axis* is non-negative and sums to 1.
    """
    logits = np.asarray(logits)
    # Softmax is invariant to a constant shift *within each slice*, so subtract
    # the per-slice maximum: the largest exponent becomes exp(0) = 1 (no
    # overflow) and every slice keeps at least one non-zero term. Using a
    # single global maximum instead can underflow an entire slice to 0 and
    # produce 0/0 = nan.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / exp_logits.sum(axis=axis, keepdims=True)

def one_hot(probs: np.ndarray, axis: "int | None" = None) -> np.ndarray:
    """Return an array shaped like *probs* with 1 at the argmax and 0 elsewhere.

    Args:
        probs: Array of scores/probabilities (anything ``np.asarray`` accepts).
        axis: If ``None`` (default), mark the single largest element of the
            whole array. Otherwise mark the largest element of every slice
            along *axis*.

    Returns:
        An array with the same shape and dtype as *probs*.

    Raises:
        ValueError: If *probs* is empty (propagated from ``np.argmax``).
    """
    probs = np.asarray(probs)
    encoded = np.zeros_like(probs)
    if axis is None:
        # np.argmax without an axis returns an index into the *flattened*
        # array; convert it to a full N-D index so that multi-dimensional
        # input sets exactly one element (not a whole row).
        encoded[np.unravel_index(np.argmax(probs), probs.shape)] = 1
    else:
        # Per-slice argmax, re-expanded so put_along_axis can scatter the 1s.
        winners = np.expand_dims(np.argmax(probs, axis=axis), axis=axis)
        np.put_along_axis(encoded, winners, 1, axis=axis)
    return encoded