File size: 1,320 Bytes
39dbdf0
 
 
 
 
6da5491
39dbdf0
6da5491
 
39dbdf0
6da5491
 
 
 
 
 
39dbdf0
6da5491
 
39dbdf0
 
 
6da5491
 
39dbdf0
6da5491
 
39dbdf0
6da5491
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate, *,
                    num_labels=2, output_dir='./results', logging_dir='./logs',
                    save_model=True):
    """Fine-tune a pre-trained sequence-classification model with ``Trainer``.

    Args:
        dataset: Mapping of split name to dataset (e.g. a
            ``datasets.DatasetDict``). Must contain a ``'train'`` split. A
            ``'validation'`` split is optional; when present it is evaluated
            once after training. Splits are handed to ``Trainer`` unchanged and
            no tokenizer is applied here, so they presumably must already be
            tokenized for ``model_name`` -- TODO confirm with callers.
        model_name: Model id or local path accepted by
            ``AutoModelForSequenceClassification.from_pretrained``.
        epochs: Number of training epochs (``num_train_epochs``).
        batch_size: Per-device training batch size.
        learning_rate: Learning rate passed to ``TrainingArguments``.
        num_labels: Size of the classification head (default 2, the value
            that was previously hard-coded).
        output_dir: Directory for checkpoints and, if ``save_model`` is true,
            the final model.
        logging_dir: Directory for training logs.
        save_model: When true, write the final model to ``output_dir`` so the
            trained weights survive after this function returns.

    Returns:
        dict: ``{"status": "Training complete", "train_metrics": {...}}``,
        plus ``"eval_metrics"`` when a validation split is present.

    Raises:
        KeyError: If ``dataset`` has no ``'train'`` split.
    """
    # Resolve splits first so a malformed dataset fails fast, before the
    # (potentially large) model download.
    train_dataset = dataset['train']
    # Many datasets ship without a validation split; indexing it
    # unconditionally used to raise KeyError.
    eval_dataset = dataset['validation'] if 'validation' in dataset else None

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        logging_dir=logging_dir,
        logging_steps=10,  # log every 10 optimizer steps
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    train_output = trainer.train()
    result = {
        "status": "Training complete",
        "train_metrics": train_output.metrics,
    }

    # Supplying eval_dataset alone never triggers evaluation, because the
    # default eval strategy is "no". Evaluate explicitly rather than setting
    # the strategy: that argument was renamed across transformers versions.
    if eval_dataset is not None:
        result["eval_metrics"] = trainer.evaluate()

    # Without this, only intermediate checkpoints (if any) remain on disk and
    # the final fine-tuned weights are discarded.
    if save_model:
        trainer.save_model()

    return result