# 13_threshold_calibration.py – per-label threshold calibration for multilabel v3

import json
import numpy as np
import torch
from datasets import load_from_disk
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
from sklearn.metrics import f1_score
from pathlib import Path

# === Paths (v3)
MODEL_DIR = Path("models/multilabel")
DATASET_DIR = Path("data/processed/dataset_multilabel_top30")
TOP_RULES_PATH = Path("data/metadata/top_rules.json")
OUTPUT_PATH = MODEL_DIR / "thresholds.json"

# === Load the rule list
with open(TOP_RULES_PATH, encoding="utf-8") as f:
    labels = json.load(f)
label_count = len(labels)

# === Model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_DIR.resolve()))
tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR.resolve()))
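
# Sanity check (added, not in the original pipeline): the classification head
# should match the rule list loaded from top_rules.json above.
assert model.config.num_labels == label_count, (
    f"Model has {model.config.num_labels} labels, expected {label_count}"
)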

# === Trainer with BCE loss
class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # transformers versions pass to compute_loss.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # BCEWithLogitsLoss treats each label as an independent binary decision,
        # the standard setup for multilabel classification.
        loss_fct = torch.nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss

# No TrainingArguments passed: Trainer falls back to defaults, which is enough
# for prediction-only use.
trainer = MultilabelTrainer(model=model)

# === Validation split
ds = load_from_disk(str(DATASET_DIR.resolve()))
val_dataset = ds["validation"]

# === Prediction
print("🔍 Generating predictions on the validation set...")
predictions = trainer.predict(val_dataset)
logits = torch.tensor(predictions.predictions)
probs = torch.sigmoid(logits).numpy()
y_true = predictions.label_ids

# === Calibration
print("⚙️ Calibrating a threshold for each rule...")
thresholds = {}
search_space = np.arange(0.05, 0.96, 0.05)  # candidate thresholds 0.05–0.95

for i, label in enumerate(labels):
    best_f1 = 0.0
    best_thresh = 0.5  # fallback if no candidate beats F1 = 0
    for t in search_space:
        y_pred = (probs[:, i] > t).astype(int)
        score = f1_score(y_true[:, i], y_pred, zero_division=0)
        if score > best_f1:
            best_f1 = score
            # cast to a plain float so the value is JSON-serializable
            best_thresh = round(float(t), 3)
    thresholds[label] = best_thresh
    print(f"📈 {label}: threshold={best_thresh} (f1={best_f1:.4f})")

# === Save thresholds
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps any non-ASCII rule names readable in the JSON
    json.dump(thresholds, f, indent=2, ensure_ascii=False)

print(f"\n✅ Zapisano progi do {OUTPUT_PATH}")