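"""Fine-tune microsoft/codebert-base for multi-label classification over the top-30 labels.

Loads a preprocessed dataset from disk, trains with early stopping on micro-F1,
and saves the best checkpoint together with the full training log history.
"""
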
import os
import json

import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
from datasets import load_from_disk
from torch.utils.data import default_collate
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

DATA_PATH = "data/processed/dataset_multilabel_top30"
OUTPUT_DIR = "models/multilabel"
MODEL_NAME = "microsoft/codebert-base"
NUM_LABELS = 30
NUM_EPOCHS = 12
SEED = 42

print("📂 Loading data and tokenizer...")
ds = load_from_disk(DATA_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("🧠 Initializing the model...")
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    problem_type="multi_label_classification",
)


def compute_metrics(pred):
    """Apply a sigmoid to the logits, threshold at 0.5, and report micro-averaged metrics."""
    logits, labels = pred
    probs = 1 / (1 + np.exp(-logits))
    preds = (probs > 0.5).astype(int)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="micro"),
        "precision": precision_score(labels, preds, average="micro"),
        "recall": recall_score(labels, preds, average="micro"),
    }


def collate_fn(batch):
    """Collate a batch and cast labels to float, as BCEWithLogitsLoss expects float targets for multi-label training."""
    batch = default_collate(batch)
    batch["labels"] = batch["labels"].float()
    return batch


args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    load_best_model_at_end=True,
    save_total_limit=2,
    seed=SEED,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    logging_steps=50,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"].with_format("torch"),
    eval_dataset=ds["validation"].with_format("torch"),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    data_collator=collate_fn,
)

print("🚀 Starting training...")
trainer.train()

print("💾 Saving model and logs...")
trainer.save_model(OUTPUT_DIR)

log_path = os.path.join(OUTPUT_DIR, "training_log.json")
with open(log_path, "w", encoding="utf-8") as f:
    json.dump(trainer.state.log_history, f, indent=2)

print(f"📝 Training log saved to {log_path}")
print("✅ Done.")
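
# Optional sanity check (a minimal sketch, not part of the original pipeline):
# reload the saved model and score one placeholder code snippet with the same
# sigmoid + 0.5 threshold used in compute_metrics above.
import torch

reloaded = AutoModelForSequenceClassification.from_pretrained(OUTPUT_DIR)
reloaded.eval()
sample = tokenizer("def add(a, b): return a + b", return_tensors="pt", truncation=True)
with torch.no_grad():
    sample_logits = reloaded(**sample).logits
sample_probs = torch.sigmoid(sample_logits)[0]
predicted_ids = (sample_probs > 0.5).nonzero(as_tuple=True)[0].tolist()
print(f"🔎 Sanity check, predicted label ids: {predicted_ids}")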