# 13_threshold_calibration.py – for multilabel v3
"""Calibrate one decision threshold per label for the multilabel classifier.

The script loads the fine-tuned model from ``MODEL_DIR``, runs it on the
``validation`` split of the dataset stored in ``DATASET_DIR``, and for each
label listed in ``TOP_RULES_PATH`` picks the probability threshold from a
fixed grid that maximises the binary F1 score of that label. The resulting
``{label: threshold}`` mapping is written to ``OUTPUT_PATH`` as JSON.
"""
import json
from pathlib import Path

import numpy as np
import torch
from datasets import load_from_disk
from sklearn.metrics import f1_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer

# === Paths (v3)
MODEL_DIR = Path("models/multilabel")
DATASET_DIR = Path("data/processed/dataset_multilabel_top30")
TOP_RULES_PATH = Path("data/metadata/top_rules.json")
OUTPUT_PATH = MODEL_DIR / "thresholds.json"

# === Load the list of rules (labels); column i of the model output is
# matched with labels[i] further down.
with open(TOP_RULES_PATH, encoding="utf-8") as f:
    labels = json.load(f)
label_count = len(labels)

# === Model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_DIR.resolve()))
# NOTE(review): the tokenizer is loaded but not referenced anywhere else in
# this script; kept so a missing/broken tokenizer still fails fast here.
tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR.resolve()))


# === Trainer with BCE loss
class MultilabelTrainer(Trainer):
    """Trainer variant that scores logits with binary cross-entropy."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the BCE-with-logits loss for one batch.

        Args:
            model: Model to run the forward pass with.
            inputs: Batch mapping; its ``"labels"`` entry is removed and used
                as the target, the rest is passed to ``model``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` releases that pass it; not used here.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        loss_fct = torch.nn.BCEWithLogitsLoss()
        # BCEWithLogitsLoss needs float targets of the same shape as logits.
        loss = loss_fct(outputs.logits, targets.float())
        return (loss, outputs) if return_outputs else loss


trainer = MultilabelTrainer(model=model)

# === Validation data
ds = load_from_disk(str(DATASET_DIR.resolve()))
val_dataset = ds["validation"]

# === Prediction
print("🔍 Generowanie predykcji na zbiorze walidacyjnym...")
predictions = trainer.predict(val_dataset)
logits = torch.tensor(predictions.predictions)
probs = torch.sigmoid(logits).numpy()
y_true = predictions.label_ids

# Fail early with a clear message instead of an IndexError (or a silently
# wrong label/column pairing) inside the calibration loop.
if y_true is None:
    raise ValueError("Validation split provides no labels; cannot calibrate thresholds.")
if probs.ndim != 2 or probs.shape[1] != label_count:
    raise ValueError(
        f"Model output shape {probs.shape} does not match the "
        f"{label_count} labels listed in {TOP_RULES_PATH}."
    )

# === Calibration
print("⚙️ Kalibracja progów dla każdej reguły...")
thresholds = {}
search_space = np.arange(0.05, 0.96, 0.05)

for i, label in enumerate(labels):
    # Defaults are kept when no grid point yields an F1 above zero.
    best_f1 = 0.0
    best_thresh = 0.5
    for t in search_space:
        y_pred = (probs[:, i] > t).astype(int)
        score = f1_score(y_true[:, i], y_pred, zero_division=0)
        # Strict comparison: on ties the lowest threshold seen first wins.
        if score > best_f1:
            best_f1 = score
            # Plain Python float so the JSON payload holds no NumPy scalars.
            best_thresh = round(float(t), 3)
    thresholds[label] = best_thresh
    print(f"📈 {label}: próg={best_thresh} (f1={best_f1:.4f})")

# === Save thresholds
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(thresholds, f, indent=2)

print(f"\n✅ Zapisano progi do {OUTPUT_PATH}")