# scripts/10.2_train_multilabel_model.py
import os
import json

import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
from datasets import load_from_disk
from torch.utils.data import default_collate
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# === Configuration
DATA_PATH = "data/processed/dataset_multilabel_top30"
OUTPUT_DIR = "models/multilabel"
MODEL_NAME = "microsoft/codebert-base"
NUM_LABELS = 30
NUM_EPOCHS = 12
SEED = 42

# === Load data and tokenizer
print("📂 Loading data and tokenizer...")
ds = load_from_disk(DATA_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# === Model
print("🧠 Initializing model...")
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    # multi_label_classification makes the model use BCEWithLogitsLoss,
    # so each of the 30 labels is predicted independently.
    problem_type="multi_label_classification",
)

# === Metrics
def compute_metrics(pred):
    logits, labels = pred
    probs = 1 / (1 + np.exp(-logits))  # sigmoid over raw logits
    preds = (probs > 0.5).astype(int)  # independent 0.5 threshold per label
    return {
        # For multi-label indicator arrays, accuracy_score is subset
        # (exact-match) accuracy: all 30 labels must be correct for a hit.
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="micro"),
        "precision": precision_score(labels, preds, average="micro"),
        "recall": recall_score(labels, preds, average="micro"),
    }

# === Batch collator: force float32 labels (BCEWithLogitsLoss requires float targets)
def collate_fn(batch):
    batch = default_collate(batch)
    batch["labels"] = batch["labels"].float()
    return batch

# === Training arguments
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    load_best_model_at_end=True,
    save_total_limit=2,
    seed=SEED,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=50,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none",
)

# === Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"].with_format("torch"),
    eval_dataset=ds["validation"].with_format("torch"),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Stop if micro-F1 on the validation set fails to improve for 2 epochs.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    data_collator=collate_fn,
)

# === Training
print("🚀 Starting training...")
trainer.train()

# === Save model and logs
print("💾 Saving model and logs...")
trainer.save_model(OUTPUT_DIR)
log_path = os.path.join(OUTPUT_DIR, "training_log.json")
with open(log_path, "w", encoding="utf-8") as f:
    json.dump(trainer.state.log_history, f, indent=2)
print(f"📝 Saved training log to {log_path}")
print("✅ Done.")
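
# --- Usage sketch (commented out so the training script's behavior is unchanged):
# how the checkpoint saved to OUTPUT_DIR could be loaded for multi-label inference.
# This is a minimal example, not part of the pipeline above; the sample code
# snippet is illustrative, the 0.5 threshold simply mirrors compute_metrics,
# and mapping indices back to label names depends on how the dataset's 30
# labels were encoded.
#
#     import torch
#     from transformers import AutoTokenizer, AutoModelForSequenceClassification
#
#     tok = AutoTokenizer.from_pretrained("models/multilabel")
#     mdl = AutoModelForSequenceClassification.from_pretrained("models/multilabel")
#     mdl.eval()
#
#     inputs = tok("def add(a, b): return a + b",
#                  return_tensors="pt", truncation=True, max_length=512)
#     with torch.no_grad():
#         logits = mdl(**inputs).logits          # shape: (1, 30)
#     probs = torch.sigmoid(logits)[0]           # one probability per label
#     predicted = (probs > 0.5).nonzero().flatten().tolist()
#     print(predicted)                           # indices of labels above threshold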