# scripts/10.2_train_multilabel_model.py
import os
import json
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
from datasets import load_from_disk
from torch.utils.data import default_collate
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
# === Configuration
DATA_PATH = "data/processed/dataset_multilabel_top30"
OUTPUT_DIR = "models/multilabel"
MODEL_NAME = "microsoft/codebert-base"
NUM_LABELS = 30
NUM_EPOCHS = 12
SEED = 42
# === Load dataset and tokenizer
print("📂 Loading dataset and tokenizer...")
ds = load_from_disk(DATA_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
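# NOTE (editor): the dataset at DATA_PATH is expected to be a DatasetDict with
# "train" and "validation" splits that are already tokenized (input_ids,
# attention_mask) and carry multi-hot "labels" vectors of length NUM_LABELS,
# presumably produced by an earlier preprocessing script in this pipeline.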
# === Model
print("🧠 Inicjalizacja modelu...")
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    problem_type="multi_label_classification"
)
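# With problem_type="multi_label_classification", the Hugging Face model uses
# BCEWithLogitsLoss, which expects float targets and independent per-label
# sigmoid outputs; this is why the collator below casts labels to float and the
# metrics apply a sigmoid with a 0.5 threshold.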
# === Metrics function
def compute_metrics(pred):
    logits, labels = pred
    probs = 1 / (1 + np.exp(-logits))  # sigmoid
    preds = (probs > 0.5).astype(int)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="micro"),
        "precision": precision_score(labels, preds, average="micro"),
        "recall": recall_score(labels, preds, average="micro"),
    }
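# Note: on multi-hot label arrays, sklearn's accuracy_score is subset accuracy
# (an example counts as correct only if all 30 labels match exactly), while the
# micro-averaged F1/precision/recall aggregate over individual label decisions.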
# === Batch collator: cast labels to float32
def collate_fn(batch):
    batch = default_collate(batch)
    batch["labels"] = batch["labels"].float()
    return batch
# === Training arguments
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    load_best_model_at_end=True,
    save_total_limit=2,
    seed=SEED,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=50,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none"
)
# === Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"].with_format("torch"),
    eval_dataset=ds["validation"].with_format("torch"),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    data_collator=collate_fn,
)
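# EarlyStoppingCallback requires load_best_model_at_end=True and a
# metric_for_best_model (both set above): training stops after 2 consecutive
# evaluations without an improvement in micro-F1, and the best checkpoint is
# restored before saving.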
# === Training
print("🚀 Starting training...")
trainer.train()
# === Save model and logs
print("💾 Saving model and logs...")
trainer.save_model(OUTPUT_DIR)
log_path = os.path.join(OUTPUT_DIR, "training_log.json")
with open(log_path, "w", encoding="utf-8") as f:
    json.dump(trainer.state.log_history, f, indent=2)
print(f"📝 Training log saved to {log_path}")
print("✅ Done.")