# train_model.py (Training Script)
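#
# A minimal example invocation, assuming HF_API_TOKEN is exported in the
# environment (the flag values below are illustrative, not prescribed defaults):
#
#   python train_model.py \
#       --task generation \
#       --model_name my-tiny-gpt2 \
#       --dataset_name wikitext/wikitext-2-raw-v1 \
#       --num_layers 4 --attention_heads 4 --hidden_size 256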
import argparse
from transformers import (
GPT2Config,
GPT2LMHeadModel,
BertConfig,
BertForSequenceClassification,
Trainer,
TrainingArguments,
AutoTokenizer,
DataCollatorForLanguageModeling,
DataCollatorWithPadding,
)
from datasets import load_dataset
import torch
import os
from huggingface_hub import login, HfApi
import logging


def setup_logging(log_file_path):
"""
Sets up logging to both console and a file.
"""
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Create handlers
c_handler = logging.StreamHandler()
f_handler = logging.FileHandler(log_file_path)
c_handler.setLevel(logging.INFO)
f_handler.setLevel(logging.INFO)
# Create formatters and add to handlers
c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
c_handler.setFormatter(c_format)
f_handler.setFormatter(f_format)
# Add handlers to the logger
logger.addHandler(c_handler)
logger.addHandler(f_handler)


def parse_arguments():
"""
Parses command-line arguments.
"""
parser = argparse.ArgumentParser(description="Train a custom LLM.")
parser.add_argument("--task", type=str, required=True, choices=["generation", "classification"],
help="Task type: 'generation' or 'classification'")
parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
parser.add_argument("--dataset_name", type=str, required=True, help="Name of the Hugging Face dataset (e.g., 'wikitext/wikitext-2-raw-v1')")
parser.add_argument("--num_layers", type=int, default=12, help="Number of hidden layers")
parser.add_argument("--attention_heads", type=int, default=1, help="Number of attention heads")
parser.add_argument("--hidden_size", type=int, default=64, help="Hidden size of the model")
parser.add_argument("--vocab_size", type=int, default=30000, help="Vocabulary size")
parser.add_argument("--sequence_length", type=int, default=512, help="Maximum sequence length")
args = parser.parse_args()
return args


def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
"""
Loads and tokenizes the dataset based on the task.
"""
logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
    try:
        if task == "generation":
            # NOTE: this script currently trains on a fixed 1% slice of
            # wikitext-103, regardless of the value passed via --dataset_name.
            # The dataset is public, so no token argument is needed.
            dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train[:1%]')
            logging.info("Dataset loaded successfully for generation task.")
        elif task == "classification":
            # NOTE: likewise, a fixed 1% slice of the public IMDB dataset is used.
            # It provides the expected 'text' and 'label' columns.
            dataset = load_dataset("stanfordnlp/imdb", split='train[:1%]')
            logging.info("Dataset loaded successfully for classification task.")
        else:
            raise ValueError("Unsupported task type")

        # Both tasks tokenize the raw 'text' column in the same way.
        def tokenize_function(examples):
            return tokenizer(examples['text'], truncation=True, max_length=sequence_length)

        # Shuffle, keep a small subset (capped at the slice size), tokenize, and
        # drop the raw 'text' column so the data collators only receive
        # tensor-friendly fields.
        subset_size = min(500, len(dataset))
        tokenized_datasets = dataset.shuffle(seed=42).select(range(subset_size)).map(
            tokenize_function, batched=True, remove_columns=['text']
        )
logging.info("Dataset tokenization complete.")
return tokenized_datasets
except Exception as e:
logging.error(f"Error loading or tokenizing dataset: {str(e)}")
raise e


def initialize_model(task, model_name, vocab_size, sequence_length, hidden_size, num_layers, attention_heads):
"""
Initializes the model configuration and model based on the task.
"""
logging.info(f"Initializing model for task '{task}'...")
try:
if task == "generation":
            config = GPT2Config(
                vocab_size=vocab_size,
                n_positions=sequence_length,
                n_embd=hidden_size,
                n_layer=num_layers,
                n_head=attention_heads,
                n_inner=4 * hidden_size,  # feed-forward width
                activation_function='gelu',
                use_cache=True
            )
model = GPT2LMHeadModel(config)
logging.info("GPT2LMHeadModel initialized successfully.")
elif task == "classification":
config = BertConfig(
vocab_size=vocab_size,
max_position_embeddings=sequence_length,
hidden_size=hidden_size,
num_hidden_layers=num_layers,
num_attention_heads=attention_heads,
intermediate_size=4 * hidden_size,
hidden_act='gelu',
num_labels=2 # Adjust based on your classification task
)
model = BertForSequenceClassification(config)
logging.info("BertForSequenceClassification initialized successfully.")
else:
raise ValueError("Unsupported task type")
return model
except Exception as e:
logging.error(f"Error initializing model: {str(e)}")
raise e


def main():
# Parse arguments
args = parse_arguments()
# Setup logging
log_file = "training.log"
setup_logging(log_file)
logging.info("Training script started.")
# Initialize Hugging Face API
api = HfApi()
# Retrieve the Hugging Face API token from environment variables
hf_token = os.getenv("HF_API_TOKEN")
if not hf_token:
logging.error("HF_API_TOKEN environment variable not set.")
raise ValueError("HF_API_TOKEN environment variable not set.")
# Perform login using the API token
try:
login(token=hf_token)
logging.info("Successfully logged in to Hugging Face Hub.")
except Exception as e:
logging.error(f"Failed to log in to Hugging Face Hub: {str(e)}")
raise e
# Initialize tokenizer
try:
logging.info("Initializing tokenizer...")
if args.task == "generation":
tokenizer = AutoTokenizer.from_pretrained("gpt2")
elif args.task == "classification":
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
else:
raise ValueError("Unsupported task type")
logging.info("Tokenizer initialized successfully.")
except Exception as e:
logging.error(f"Error initializing tokenizer: {str(e)}")
raise e
# Load and prepare dataset
try:
tokenized_datasets = load_and_prepare_dataset(
task=args.task,
dataset_name=args.dataset_name,
tokenizer=tokenizer,
sequence_length=args.sequence_length
)
except Exception as e:
logging.error("Failed to load and prepare dataset.")
raise e
    # Initialize model
    # The embedding matrix must cover every id the tokenizer can produce,
    # otherwise training fails with an out-of-range index error.
    model_vocab_size = max(args.vocab_size, len(tokenizer))
    try:
        model = initialize_model(
            task=args.task,
            model_name=args.model_name,
            vocab_size=model_vocab_size,
sequence_length=args.sequence_length,
hidden_size=args.hidden_size,
num_layers=args.num_layers,
attention_heads=args.attention_heads
)
except Exception as e:
logging.error("Failed to initialize model.")
raise e
# Define data collator
if args.task == "generation":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif args.task == "classification":
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
else:
logging.error("Unsupported task type for data collator.")
raise ValueError("Unsupported task type for data collator.")
# Define training arguments
if args.task == "generation":
training_args = TrainingArguments(
output_dir=f"./models/{args.model_name}",
num_train_epochs=3,
per_device_train_batch_size=8,
save_steps=5000,
save_total_limit=2,
logging_steps=500,
learning_rate=5e-4,
remove_unused_columns=False,
push_to_hub=False # We'll handle pushing manually
)
elif args.task == "classification":
training_args = TrainingArguments(
output_dir=f"./models/{args.model_name}",
num_train_epochs=3,
per_device_train_batch_size=16,
evaluation_strategy="epoch",
save_steps=5000,
save_total_limit=2,
logging_steps=500,
learning_rate=5e-5,
remove_unused_columns=False,
push_to_hub=False # We'll handle pushing manually
)
else:
logging.error("Unsupported task type for training arguments.")
raise ValueError("Unsupported task type for training arguments.")
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets,
data_collator=data_collator,
)
# Start training
logging.info("Starting training...")
try:
trainer.train()
logging.info("Training completed successfully.")
except Exception as e:
logging.error(f"Error during training: {str(e)}")
raise e
# Save the final model and tokenizer
try:
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
logging.info(f"Model and tokenizer saved to '{training_args.output_dir}'.")
except Exception as e:
logging.error(f"Error saving model or tokenizer: {str(e)}")
raise e
# Push the model to Hugging Face Hub
model_repo = f"{api.whoami(token=hf_token)['name']}/{args.model_name}"
try:
logging.info(f"Pushing model to Hugging Face Hub at '{model_repo}'...")
api.create_repo(repo_id=model_repo, private=False, token=hf_token)
logging.info(f"Repository '{model_repo}' created successfully.")
except Exception as e:
logging.warning(f"Repository might already exist: {str(e)}")
try:
        model.push_to_hub(model_repo, token=hf_token)
        tokenizer.push_to_hub(model_repo, token=hf_token)
logging.info(f"Model and tokenizer pushed to Hugging Face Hub at '{model_repo}'.")
except Exception as e:
logging.error(f"Error pushing model to Hugging Face Hub: {str(e)}")
raise e
logging.info("Training script finished successfully.")
if __name__ == "__main__":
main()
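
# After a successful run, the pushed model can be reloaded from the Hub.
# A minimal sketch (illustrative; substitute your own Hub username and the
# --model_name you trained with):
#
#   from transformers import GPT2LMHeadModel, AutoTokenizer
#   model = GPT2LMHeadModel.from_pretrained("<username>/<model_name>")
#   tokenizer = AutoTokenizer.from_pretrained("<username>/<model_name>")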