|
|
|
|
|
import json |
|
import numpy as np |
|
import torch |
|
from datasets import load_from_disk |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer |
|
from sklearn.metrics import f1_score |
|
from pathlib import Path |
|
|
|
|
|
# ── Paths ─────────────────────────────────────────────────────────────────
MODEL_DIR = Path("models/multilabel")
DATASET_DIR = Path("data/processed/dataset_multilabel_top30")
TOP_RULES_PATH = Path("data/metadata/top_rules.json")
OUTPUT_PATH = MODEL_DIR / "thresholds.json"

# Ordered list of rule names; assumed to match the model's output-logit
# order (index i of the logits corresponds to labels[i]) — confirm against
# the training pipeline.
labels = json.loads(TOP_RULES_PATH.read_text(encoding="utf-8"))
label_count = len(labels)

# Restore the fine-tuned multi-label classifier and its tokenizer from disk.
_model_path = str(MODEL_DIR.resolve())
model = AutoModelForSequenceClassification.from_pretrained(_model_path)
tokenizer = AutoTokenizer.from_pretrained(_model_path)
|
|
|
|
|
class MultilabelTrainer(Trainer):
    """Trainer variant that optimizes multi-label targets with BCE-with-logits.

    The stock ``Trainer`` loss assumes a single-label objective; here each
    sample may activate several labels at once, so the loss treats every
    label as an independent binary classification.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the multi-label BCE loss for one batch.

        Args:
            model: The model to run on the batch.
            inputs: Batch dict; the ``"labels"`` entry is popped and used as
                the multi-hot target matrix.
            return_outputs: If True, also return the raw model outputs.
            num_items_in_batch: Accepted for compatibility with
                transformers >= 4.46, which passes this extra argument to
                ``compute_loss``; unused here.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``
            is True.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # One independent sigmoid/BCE term per label; targets cast to float
        # because BCEWithLogitsLoss requires floating-point targets.
        loss_fct = torch.nn.BCEWithLogitsLoss()
        loss = loss_fct(logits, labels.float())
        return (loss, outputs) if return_outputs else loss
|
|
|
# The Trainer is used purely for its batched predict() convenience.
trainer = MultilabelTrainer(model=model)

# Calibration runs on the validation split only.
val_dataset = load_from_disk(str(DATASET_DIR.resolve()))["validation"]

print("🔍 Generowanie predykcji na zbiorze walidacyjnym...")
predictions = trainer.predict(val_dataset)
# Elementwise sigmoid turns raw logits into independent per-label probabilities.
probs = torch.sigmoid(torch.tensor(predictions.predictions)).numpy()
y_true = predictions.label_ids
|
|
|
|
|
print("⚙️ Kalibracja progów dla każdej reguły...")
thresholds = {}
# Candidate cutoffs 0.05, 0.10, ..., 0.95.
search_space = np.arange(0.05, 0.96, 0.05)

# Pick, per label, the cutoff maximizing F1 on the validation set.
# Strict ">" keeps the first (lowest) threshold on ties, and 0.5 stands
# as the fallback when no cutoff beats an F1 of zero.
for idx, rule in enumerate(labels):
    column = probs[:, idx]
    truth = y_true[:, idx]
    top_score, top_thresh = 0.0, 0.5
    for candidate in search_space:
        predicted = (column > candidate).astype(int)
        score = f1_score(truth, predicted, zero_division=0)
        if score > top_score:
            top_score, top_thresh = score, round(candidate, 3)
    thresholds[rule] = top_thresh
    print(f"📈 {rule}: próg={top_thresh} (f1={top_score:.4f})")
|
|
|
|
|
# Persist the calibrated thresholds next to the model artifacts.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.write_text(json.dumps(thresholds, indent=2), encoding="utf-8")

print(f"\n✅ Zapisano progi do {OUTPUT_PATH}")
|
|