# train_model.py (Training Script)
import argparse
import logging
import os

import torch
from datasets import load_dataset
from huggingface_hub import HfApi, HfFolder, login
from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForSequenceClassification,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
    GPT2Config,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
)
def setup_logging(log_file_path):
    """
    Configure the root logger to emit INFO-level records to both the
    console and a log file.

    Args:
        log_file_path: Path of the file that receives log output.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # The original attached a fresh handler pair on every call, so a
    # second call duplicated every log line. Skip re-configuration if a
    # file handler is already installed.
    if any(isinstance(h, logging.FileHandler) for h in logger.handlers):
        return

    # Console and file share one format.
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
def parse_arguments(argv=None):
    """
    Parse command-line arguments for the training script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv`` (useful for testing). Defaults to None, which
            keeps the original behavior of reading ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train a custom LLM.")
    parser.add_argument("--task", type=str, required=True,
                        choices=["generation", "classification"],
                        help="Task type: 'generation' or 'classification'")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name of the model")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="Name of the Hugging Face dataset (e.g., 'wikitext/wikitext-2-raw-v1')")
    parser.add_argument("--num_layers", type=int, default=12,
                        help="Number of hidden layers")
    parser.add_argument("--attention_heads", type=int, default=1,
                        help="Number of attention heads")
    parser.add_argument("--hidden_size", type=int, default=64,
                        help="Hidden size of the model")
    parser.add_argument("--vocab_size", type=int, default=30000,
                        help="Vocabulary size")
    parser.add_argument("--sequence_length", type=int, default=512,
                        help="Maximum sequence length")
    args = parser.parse_args(argv)

    # Transformer attention requires the hidden size to split evenly
    # across heads; fail at parse time instead of deep inside model init.
    if args.hidden_size % args.attention_heads != 0:
        parser.error("--hidden_size must be divisible by --attention_heads")

    return args
def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
    """
    Load a Hugging Face dataset, keep a small shuffled subset, and
    tokenize it for the given task.

    Args:
        task: Either "generation" or "classification".
        dataset_name: Hub dataset identifier. May be a plain repo id
            (e.g. "stanfordnlp/imdb") or a "name/config" pair
            (e.g. "wikitext/wikitext-2-raw-v1").
        tokenizer: Tokenizer used to encode the 'text' column.
        sequence_length: Maximum token length per example (truncation).

    Returns:
        The tokenized, shuffled subset (up to 500 examples).

    Raises:
        ValueError: If ``task`` is not a supported value.
        Exception: Re-raises any dataset loading/tokenization failure.
    """
    logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
    if task not in ("generation", "classification"):
        raise ValueError("Unsupported task type")
    try:
        # BUG FIX: the original parsed dataset_name and then ignored it,
        # always loading a hard-coded dataset. Honor the argument. Try
        # the string as-is first (it may be an "org/dataset" repo id);
        # if that fails and it contains a slash, retry treating the part
        # after the slash as a config name.
        try:
            dataset = load_dataset(dataset_name, split='train[:1%]')
        except Exception:
            if '/' not in dataset_name:
                raise
            name, config = dataset_name.split('/', 1)
            dataset = load_dataset(name, config, split='train[:1%]')
        logging.info(f"Dataset loaded successfully for {task} task.")

        def tokenize_function(examples):
            # Both tasks encode the 'text' column; for classification
            # the 'label' column (assumed present) passes through map()
            # untouched.
            return tokenizer(examples['text'], truncation=True,
                             max_length=sequence_length)

        # Shuffle deterministically and cap the subset; the original
        # select(range(500)) crashed on datasets with fewer than 500 rows.
        subset_size = min(500, len(dataset))
        tokenized_datasets = (dataset.shuffle(seed=42)
                              .select(range(subset_size))
                              .map(tokenize_function, batched=True))
        logging.info("Dataset tokenization complete.")
        return tokenized_datasets
    except Exception as e:
        logging.error(f"Error loading or tokenizing dataset: {str(e)}")
        raise
def initialize_model(task, model_name, vocab_size, sequence_length,
                     hidden_size, num_layers, attention_heads):
    """
    Build a freshly-initialized (untrained) model for the given task.

    Args:
        task: Either "generation" (GPT-2 architecture) or
            "classification" (BERT architecture).
        model_name: Accepted for interface compatibility; not used here.
        vocab_size: Vocabulary size of the model's embedding table.
        sequence_length: Maximum position embeddings / context length.
        hidden_size: Model (embedding) dimension.
        num_layers: Number of transformer layers.
        attention_heads: Number of attention heads per layer.

    Returns:
        A GPT2LMHeadModel or BertForSequenceClassification with random
        weights.

    Raises:
        ValueError: If ``task`` is not a supported value.
    """
    logging.info(f"Initializing model for task '{task}'...")
    try:
        if task == "generation":
            # BUG FIX: GPT2Config maps num_hidden_layers/num_attention_heads
            # via its attribute_map, but intermediate_size and hidden_act are
            # NOT mapped — the original stored them as inert extra attributes
            # while n_inner/activation_function kept their defaults. Use the
            # canonical GPT-2 parameter names. (n_ctx is obsolete in modern
            # transformers; n_positions is the context length.)
            config = GPT2Config(
                vocab_size=vocab_size,
                n_positions=sequence_length,
                n_embd=hidden_size,
                n_layer=num_layers,
                n_head=attention_heads,
                n_inner=4 * hidden_size,
                activation_function='gelu',
                use_cache=True,
            )
            model = GPT2LMHeadModel(config)
            logging.info("GPT2LMHeadModel initialized successfully.")
        elif task == "classification":
            config = BertConfig(
                vocab_size=vocab_size,
                max_position_embeddings=sequence_length,
                hidden_size=hidden_size,
                num_hidden_layers=num_layers,
                num_attention_heads=attention_heads,
                intermediate_size=4 * hidden_size,
                hidden_act='gelu',
                num_labels=2,  # binary classification (e.g. IMDB sentiment)
            )
            model = BertForSequenceClassification(config)
            logging.info("BertForSequenceClassification initialized successfully.")
        else:
            raise ValueError("Unsupported task type")
        return model
    except Exception as e:
        logging.error(f"Error initializing model: {str(e)}")
        raise
def main():
    """
    Entry point: parse arguments, authenticate with the Hugging Face
    Hub, build the tokenizer/dataset/model for the requested task,
    train, save locally, and push the result to the Hub.

    Raises:
        ValueError: If HF_API_TOKEN is unset or the task is unsupported.
    """
    args = parse_arguments()

    log_file = "training.log"
    setup_logging(log_file)
    logging.info("Training script started.")

    api = HfApi()

    # The Hub token must come from the environment; fail fast if absent.
    hf_token = os.getenv("HF_API_TOKEN")
    if not hf_token:
        logging.error("HF_API_TOKEN environment variable not set.")
        raise ValueError("HF_API_TOKEN environment variable not set.")

    try:
        login(token=hf_token)
        logging.info("Successfully logged in to Hugging Face Hub.")
    except Exception as e:
        logging.error(f"Failed to log in to Hugging Face Hub: {str(e)}")
        raise

    # Tokenizer: reuse a pretrained vocabulary matching the architecture.
    try:
        logging.info("Initializing tokenizer...")
        if args.task == "generation":
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            # BUG FIX: GPT-2 ships without a padding token, and
            # DataCollatorForLanguageModeling needs one to batch
            # variable-length sequences. Reuse EOS as pad.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        elif args.task == "classification":
            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        else:
            raise ValueError("Unsupported task type")
        logging.info("Tokenizer initialized successfully.")
    except Exception as e:
        logging.error(f"Error initializing tokenizer: {str(e)}")
        raise

    try:
        tokenized_datasets = load_and_prepare_dataset(
            task=args.task,
            dataset_name=args.dataset_name,
            tokenizer=tokenizer,
            sequence_length=args.sequence_length,
        )
    except Exception:
        logging.error("Failed to load and prepare dataset.")
        raise

    try:
        model = initialize_model(
            task=args.task,
            model_name=args.model_name,
            vocab_size=args.vocab_size,
            sequence_length=args.sequence_length,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            attention_heads=args.attention_heads,
        )
        # BUG FIX: the custom config's vocab_size (default 30000) can be
        # smaller than the pretrained tokenizer's vocabulary (~50k for
        # gpt2), which would cause out-of-range embedding lookups during
        # training. Grow/shrink the embedding table to match.
        model.resize_token_embeddings(len(tokenizer))
    except Exception:
        logging.error("Failed to initialize model.")
        raise

    # Collator matches the task: causal-LM labels vs. padded batches.
    if args.task == "generation":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif args.task == "classification":
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    else:
        logging.error("Unsupported task type for data collator.")
        raise ValueError("Unsupported task type for data collator.")

    if args.task == "generation":
        training_args = TrainingArguments(
            output_dir=f"./models/{args.model_name}",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            save_steps=5000,
            save_total_limit=2,
            logging_steps=500,
            learning_rate=5e-4,
            remove_unused_columns=False,
            push_to_hub=False,  # pushing is handled manually below
        )
    elif args.task == "classification":
        training_args = TrainingArguments(
            output_dir=f"./models/{args.model_name}",
            num_train_epochs=3,
            per_device_train_batch_size=16,
            # NOTE(review): this kwarg was renamed to eval_strategy in
            # transformers >= 4.41 — confirm against the pinned version.
            evaluation_strategy="epoch",
            save_steps=5000,
            save_total_limit=2,
            logging_steps=500,
            learning_rate=5e-5,
            remove_unused_columns=False,
            push_to_hub=False,  # pushing is handled manually below
        )
    else:
        logging.error("Unsupported task type for training arguments.")
        raise ValueError("Unsupported task type for training arguments.")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets,
        data_collator=data_collator,
    )

    logging.info("Starting training...")
    try:
        trainer.train()
        logging.info("Training completed successfully.")
    except Exception as e:
        logging.error(f"Error during training: {str(e)}")
        raise

    try:
        trainer.save_model(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        logging.info(f"Model and tokenizer saved to '{training_args.output_dir}'.")
    except Exception as e:
        logging.error(f"Error saving model or tokenizer: {str(e)}")
        raise

    # Push to the Hub under the authenticated user's namespace.
    model_repo = f"{api.whoami(token=hf_token)['name']}/{args.model_name}"
    try:
        logging.info(f"Pushing model to Hugging Face Hub at '{model_repo}'...")
        # exist_ok=True makes re-runs idempotent instead of relying on a
        # catch-all "might already exist" warning for every failure mode.
        api.create_repo(repo_id=model_repo, private=False, token=hf_token,
                        exist_ok=True)
        logging.info(f"Repository '{model_repo}' created successfully.")
    except Exception as e:
        logging.warning(f"Repository might already exist: {str(e)}")

    try:
        # use_auth_token is deprecated/removed in recent hub libraries;
        # token= is the supported keyword.
        model.push_to_hub(model_repo, token=hf_token)
        tokenizer.push_to_hub(model_repo, token=hf_token)
        logging.info(f"Model and tokenizer pushed to Hugging Face Hub at '{model_repo}'.")
    except Exception as e:
        logging.error(f"Error pushing model to Hugging Face Hub: {str(e)}")
        raise

    logging.info("Training script finished successfully.")


if __name__ == "__main__":
    main()