import logging
import multiprocessing

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from trl import SFTTrainer

# Configure logging (standard-library logger; importing `logging` from
# transformers as well would shadow it and is not needed)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def preprocess_function(examples, tokenizer):
    # The oasst1 "text" column already holds the message text, so no extra
    # formatting is needed here.
    texts = list(examples["text"])

    # Tokenize with a shorter max length; the data collator converts to tensors later
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,  # Reduced from 1024 to 512
    )
    return tokenized


def main():
    try:
        # Load dataset
        logger.info("Loading dataset...")
        dataset = load_dataset("OpenAssistant/oasst1")

        # Use a smaller subset for faster training
        logger.info("Selecting smaller dataset subset...")
        dataset["train"] = dataset["train"].select(range(2000))  # Reduced to 2k examples

        # Model and tokenizer setup
        logger.info("Setting up model and tokenizer...")
        model_name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

        # Preprocess dataset
        logger.info("Preprocessing dataset...")
        tokenized_dataset = dataset.map(
            lambda x: preprocess_function(x, tokenizer),
            batched=True,
            remove_columns=dataset["train"].column_names,
            num_proc=4,  # Parallel processing for faster preprocessing
        )

        # Split dataset into train and eval
        logger.info("Splitting dataset into train and eval sets...")
        split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1, seed=42)
        train_dataset = split_dataset["train"]
        eval_dataset = split_dataset["test"]

        # Configure 4-bit quantization with memory optimizations
        logger.info("Configuring 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage=torch.float16,
        )

        # Load model with quantization and memory optimizations
        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

        # Enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()

        # Prepare model for k-bit training
        logger.info("Preparing model for k-bit training...")
        model = prepare_model_for_kbit_training(model)

        # LoRA configuration with optimized parameters
        # (phi-2's attention output projection is named "dense", not "o_proj")
        logger.info("Configuring LoRA...")
        lora_config = LoraConfig(
            r=8,  # Reduced from 16 to 8
            lora_alpha=16,  # Reduced from 32 to 16
            target_modules=["q_proj", "k_proj", "v_proj", "dense"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Get PEFT model
        logger.info("Getting PEFT model...")
        model = get_peft_model(model, lora_config)

        # Create data collator for causal language modeling (no masking)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )

        # Training arguments with memory-optimized settings
        logger.info("Setting up training arguments...")
        training_args = TrainingArguments(
            output_dir="./phi2-qlora",
            num_train_epochs=2,
            per_device_train_batch_size=4,  # Reduced from 16 to 4
            per_device_eval_batch_size=4,
            gradient_accumulation_steps=4,  # Increased from 1 to 4
            learning_rate=2e-4,
            fp16=True,
            logging_steps=5,
            save_strategy="epoch",
            evaluation_strategy="epoch",
            # Additional optimizations
            dataloader_num_workers=2,  # Reduced from 4 to 2
            dataloader_pin_memory=True,
            warmup_ratio=0.05,
            lr_scheduler_type="cosine",
            optim="adamw_torch",
            max_grad_norm=1.0,
            group_by_length=True,
        )

        # Create trainer
        logger.info("Creating trainer...")
        trainer = SFTTrainer(
            model=model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            args=training_args,
            data_collator=data_collator,
        )

        # Train the model
        logger.info("Starting training...")
        trainer.train()

        # Save the model
        logger.info("Saving model...")
        trainer.save_model("./phi2-qlora-final")

        # Save tokenizer
        logger.info("Saving tokenizer...")
        tokenizer.save_pretrained("./phi2-qlora-final")

        logger.info("Training completed successfully!")

    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    main()
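

# --- Optional follow-up (not called above): inference with the saved adapter ---
# A minimal sketch, assuming the training run above completed and wrote the
# LoRA adapter and tokenizer to "./phi2-qlora-final". The prompt string and
# generation settings are illustrative placeholders, not part of the original
# training script.
def run_inference_example():
    from peft import PeftModel

    # Re-quantize the base model the same way as during training
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Attach the trained LoRA adapter and switch to eval mode
    tokenizer = AutoTokenizer.from_pretrained("./phi2-qlora-final")
    model = PeftModel.from_pretrained(base_model, "./phi2-qlora-final")
    model.eval()

    # Generate a short completion for a sample prompt (placeholder text)
    inputs = tokenizer("Explain QLoRA in one sentence.", return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))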