# Source: multilabel-dockerfile-model/scripts/13.2_threshold_calibration.py
# Uploaded by LeeSek — commit 097a740 ("Add scripts", verified)
# 13.2_threshold_calibration.py – per-rule decision-threshold calibration for the multilabel v3 model
import json
from pathlib import Path

import numpy as np
import torch
from datasets import load_from_disk
from sklearn.metrics import f1_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
)
# === Paths (v3 layout), relative to the current working directory
MODEL_DIR = Path("models") / "multilabel"
DATASET_DIR = Path("data") / "processed" / "dataset_multilabel_top30"
TOP_RULES_PATH = Path("data") / "metadata" / "top_rules.json"
# Calibrated thresholds are written into the model directory.
OUTPUT_PATH = MODEL_DIR.joinpath("thresholds.json")
# === Load the rule list; later code pairs labels[i] with model output column i
labels = json.loads(TOP_RULES_PATH.read_text(encoding="utf-8"))
label_count = len(labels)
# === Model and tokenizer, both loaded from the trained model directory
model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_DIR.resolve()))
tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR.resolve()))
# Calibration pairs output column i with labels[i] purely by position. If the
# model head and the rule list disagree in size, thresholds would either be
# silently computed for only the first len(labels) columns or crash with an
# IndexError mid-calibration -- fail fast with a clear message instead.
if model.config.num_labels != label_count:
    raise ValueError(
        f"Model at {MODEL_DIR} has {model.config.num_labels} outputs, "
        f"but {TOP_RULES_PATH} lists {label_count} rules"
    )
# === Trainer with BCE loss (each output column is an independent binary target)
class MultilabelTrainer(Trainer):
    """``Trainer`` that scores multilabel logits with ``BCEWithLogitsLoss``.

    Only ``trainer.predict`` is called in this script; threshold calibration
    uses the returned logits and label ids.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the mean BCE-with-logits loss for one batch.

        Args:
            model: The model being evaluated.
            inputs: Batch dict produced by the data collator. Must contain
                ``"labels"``; it is removed before the forward pass and the
                loss is computed here instead.
            return_outputs: If True, also return the model outputs.
            **kwargs: Extra arguments passed by newer ``transformers`` releases
                (e.g. ``num_items_in_batch``). They are accepted and ignored,
                so the override does not raise ``TypeError``.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # Named ``targets`` to avoid shadowing the module-level rule list ``labels``.
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        loss_fct = torch.nn.BCEWithLogitsLoss()
        # BCE requires float targets; label tensors may arrive as integers.
        loss = loss_fct(outputs.logits, targets.float())
        return (loss, outputs) if return_outputs else loss


# Without a collator or tokenizer, Trainer falls back to default_data_collator,
# which cannot pad variable-length examples (and the loaded tokenizer went
# unused). DataCollatorWithPadding pads each batch to its longest example;
# already-padded data passes through at the same length.
trainer = MultilabelTrainer(
    model=model,
    data_collator=DataCollatorWithPadding(tokenizer),
)
# === Validation split: thresholds are tuned on held-out data, not the training split
_dataset_path = str(DATASET_DIR.resolve())
ds = load_from_disk(_dataset_path)
val_dataset = ds["validation"]
# === Prediction: per-label probabilities plus ground truth for the validation split
print("🔍 Generowanie predykcji na zbiorze walidacyjnym...")
predictions = trainer.predict(val_dataset)
# Raw logits -> element-wise sigmoid, so every label gets its own probability.
logits = torch.as_tensor(predictions.predictions)
probs = logits.sigmoid().numpy()
y_true = predictions.label_ids
# === Calibration: grid-search the F1-maximising threshold for each rule
def _best_threshold(scores, truth, candidates):
    """Return ``(threshold, f1)`` maximising F1 for a single label.

    Candidates are tried in order, and only a strictly higher F1 replaces the
    current best, so ties resolve to the lowest threshold. If no candidate
    beats F1 = 0, the default threshold 0.5 is returned with F1 0.0.
    """
    top_thresh, top_f1 = 0.5, 0.0
    for cand in candidates:
        predicted = (scores > cand).astype(int)
        f1 = f1_score(truth, predicted, zero_division=0)
        if f1 > top_f1:
            top_thresh, top_f1 = round(cand, 3), f1
    return top_thresh, top_f1


print("⚙️ Kalibracja progów dla każdej reguły...")
thresholds = {}
# Candidate thresholds 0.05, 0.10, ..., 0.95.
search_space = np.arange(0.05, 0.96, 0.05)
for col, rule in enumerate(labels):
    thresh, f1 = _best_threshold(probs[:, col], y_true[:, col], search_space)
    thresholds[rule] = thresh
    print(f"📈 {rule}: próg={thresh} (f1={f1:.4f})")
# === Persist thresholds as a JSON object mapping rule -> threshold
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.write_text(json.dumps(thresholds, indent=2), encoding="utf-8")
print(f"\n✅ Zapisano progi do {OUTPUT_PATH}")