import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

class LabelWeightedBCELoss(nn.Module):
    """
    Binary cross-entropy loss that treats each float in the final dimension as a
    binary probability. Allows per-label weighting of the loss.
    """
    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
        super().__init__()
        self.label_weights = label_weights
        match reduction:
            case "mean":
                self.reduction = torch.mean
            case "sum":
                self.reduction = torch.sum
            case _:
                # Fail fast instead of leaving self.reduction unset.
                raise ValueError(f"Unsupported reduction: {reduction}")
    def _log(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp to -100 so log(0) cannot yield -inf (mirrors torch's BCELoss).
        return torch.clamp_min(torch.log(x), -100)
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Standard BCE on probabilities, scaled per label by label_weights.
        losses = -self.label_weights * (
            target * self._log(input) + (1 - target) * self._log(1 - input)
        )
        return self.reduction(losses)
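
# Illustrative usage sketch (an addition, not part of the original module): the
# loss expects probabilities in [0, 1], e.g. sigmoid outputs, with one weight
# per label. Shapes and weights below are assumptions for demonstration only.
def _example_label_weighted_bce() -> torch.Tensor:
    weights = torch.tensor([1.0, 2.0, 0.5])  # hypothetical per-label weights
    criterion = LabelWeightedBCELoss(weights)
    probs = torch.sigmoid(torch.randn(4, 3))  # batch of 4 samples, 3 labels
    target = torch.randint(0, 2, (4, 3)).float()  # binary targets
    return criterion(probs, target)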

# TODO: Code a one-hot encoding helper (a sketch follows).
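# A minimal sketch for the TODO above, assuming integer class indices and a
# known class count; the name and signature are placeholders, not the author's
# intended implementation.
def _one_hot_sketch(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # torch.nn.functional.one_hot returns int64; cast to float for use as targets
    return nn.functional.one_hot(labels, num_classes=num_classes).float()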

def calculate_metrics(
    pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
    target = target.detach().cpu().numpy()
    pred = pred.detach().cpu()
    # Multi-label probabilities are independent per label, so use sigmoid;
    # softmax is only appropriate for the mutually exclusive single-label case.
    pred = torch.sigmoid(pred) if multi_label else nn.functional.softmax(pred, dim=1)
    pred = pred.numpy()
    params = {
        "y_true": target if multi_label else target.argmax(1),
        "y_pred": np.array(pred > threshold, dtype=float)
        if multi_label
        else pred.argmax(1),
        "zero_division": 0,
        "average": "macro",
    }
    metrics = {
        "precision": precision_score(**params),
        "recall": recall_score(**params),
        "f1": f1_score(**params),
        "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
    }
    return {
        prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
    }
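
# Illustrative call (an assumption: raw logits of shape [batch, num_labels]);
# exercises the multi-label path, where probabilities are thresholded at 0.5.
def _example_calculate_metrics() -> dict[str, torch.Tensor]:
    logits = torch.randn(8, 5)  # hypothetical batch of 8 samples, 5 labels
    target = torch.randint(0, 2, (8, 5)).float()
    return calculate_metrics(logits, target, prefix="val_")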

class EarlyStopping:
    """Signals a stop once the tracked value has failed to improve for more
    than `patience` consecutive steps."""

    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        if self.last_measure <= val:
            self.consecutive_increase += 1
        else:
            self.consecutive_increase = 0
        self.last_measure = val
        return self.patience < self.consecutive_increase
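
# Sketch of driving a training loop with EarlyStopping; `validate` is a
# placeholder callable returning a validation loss, not part of this module.
def _example_early_stopping(validate, max_epochs: int = 100) -> None:
    stopper = EarlyStopping(patience=3)
    for _ in range(max_epochs):
        val_loss = validate()
        if stopper.step(val_loss):
            break  # loss failed to improve for more than `patience` epochs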

def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    id2label = {str(i): label for i, label in enumerate(labels)}
    label2id = {label: str(i) for i, label in enumerate(labels)}
    return id2label, label2id
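
# Example of the resulting format (ids are strings on both sides):
#     get_id_label_mapping(["cat", "dog"])
#     -> ({"0": "cat", "1": "dog"}, {"cat": "0", "dog": "1"})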

def compute_hf_metrics(eval_pred):
    # Trainer's compute_metrics hook expects a dict of named scalars, so wrap
    # the accuracy score rather than returning a bare float.
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return {"accuracy": accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)}
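
# Sketch of wiring the metric into a transformers Trainer; `model`,
# `training_args`, and the datasets are assumptions about the surrounding
# training script, not objects defined in this module.
#
#     from transformers import Trainer
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_ds,
#         eval_dataset=eval_ds,
#         compute_metrics=compute_hf_metrics,
#     )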