# Phi2_Qlora/train.py
import logging
import multiprocessing

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def preprocess_function(examples, tokenizer):
    # Tokenize the raw "text" column; truncate here and leave padding to the data
    # collator, which pads each batch to its longest sequence at training time.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,  # Reduced from 1024 to 512
    )
def main():
try:
# Load dataset
logger.info("Loading dataset...")
dataset = load_dataset("OpenAssistant/oasst1")
# Use a smaller subset for faster training
logger.info("Selecting smaller dataset subset...")
dataset["train"] = dataset["train"].select(range(2000)) # Reduced to 2k examples
# Model and tokenizer setup
logger.info("Setting up model and tokenizer...")
model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
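        # phi-2's tokenizer does not define a padding token, so the EOS token is
        # reused for padding below.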
tokenizer.pad_token = tokenizer.eos_token
# Preprocess dataset
logger.info("Preprocessing dataset...")
tokenized_dataset = dataset.map(
lambda x: preprocess_function(x, tokenizer),
batched=True,
remove_columns=dataset["train"].column_names,
num_proc=4 # Parallel processing for faster preprocessing
)
# Split dataset into train and eval
logger.info("Splitting dataset into train and eval sets...")
split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
# Configure 4-bit quantization with memory optimizations
logger.info("Configuring 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_storage=torch.float16
)
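        # NF4 keeps the frozen base weights in 4-bit NormalFloat and double
        # quantization additionally compresses the per-block quantization constants;
        # matrix multiplications still run in float16 via bnb_4bit_compute_dtype.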
# Load model with quantization and memory optimizations
logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
# Enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
# Prepare model for k-bit training
logger.info("Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)
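        # In current PEFT versions this call freezes the base weights, upcasts the
        # remaining float16 parameters (e.g. layer norms) to float32 for numerical
        # stability, and re-enables gradient checkpointing, so the explicit calls
        # above are harmless but partly redundant.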
# LoRA configuration with optimized parameters
logger.info("Configuring LoRA...")
lora_config = LoraConfig(
r=8, # Reduced from 16 to 8
lora_alpha=16, # Reduced from 32 to 16
            target_modules=["q_proj", "k_proj", "v_proj", "dense"],  # phi-2 names its attention output projection "dense", not "o_proj"
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
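        # With r=8 and lora_alpha=16 the LoRA update is scaled by alpha/r = 2.0,
        # a common ratio for QLoRA fine-tuning.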
# Get PEFT model
logger.info("Getting PEFT model...")
model = get_peft_model(model, lora_config)
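        # Report how many parameters LoRA actually trains (typically well under 1%
        # of the ~2.7B base parameters).
        model.print_trainable_parameters()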
# Create data collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False
)
# Training arguments with memory-optimized settings
logger.info("Setting up training arguments...")
training_args = TrainingArguments(
output_dir="./phi2-qlora",
num_train_epochs=2,
per_device_train_batch_size=4, # Reduced from 16 to 4
per_device_eval_batch_size=4,
gradient_accumulation_steps=4, # Increased from 1 to 4
learning_rate=2e-4,
fp16=True,
logging_steps=5,
save_strategy="epoch",
evaluation_strategy="epoch",
# Additional optimizations
dataloader_num_workers=2, # Reduced from 4 to 2
dataloader_pin_memory=True,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
optim="adamw_torch",
max_grad_norm=1.0,
group_by_length=True,
)
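        # Effective batch size stays at 4 x 4 = 16 sequences per device: gradient
        # accumulation recovers the original batch size of 16 while holding only
        # 4 sequences' activations in memory at a time.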
# Create trainer
logger.info("Creating trainer...")
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
args=training_args,
data_collator=data_collator,
)
# Train the model
logger.info("Starting training...")
trainer.train()
# Save the model
logger.info("Saving model...")
trainer.save_model("./phi2-qlora-final")
# Save tokenizer
logger.info("Saving tokenizer...")
tokenizer.save_pretrained("./phi2-qlora-final")
logger.info("Training completed successfully!")
except Exception as e:
logger.error(f"An error occurred: {str(e)}")
raise
if __name__ == "__main__":
multiprocessing.set_start_method('spawn')
main()
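
# A minimal sketch (kept as a comment, not executed) of loading the saved adapter
# for inference afterwards, assuming the ./phi2-qlora-final directory written above:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     from peft import PeftModel
#     import torch
#
#     base = AutoModelForCausalLM.from_pretrained(
#         "microsoft/phi-2", torch_dtype=torch.float16,
#         device_map="auto", trust_remote_code=True,
#     )
#     model = PeftModel.from_pretrained(base, "./phi2-qlora-final")
#     tok = AutoTokenizer.from_pretrained("./phi2-qlora-final")
#     inputs = tok("Instruct: Explain QLoRA briefly.\nOutput:", return_tensors="pt").to(model.device)
#     print(tok.decode(model.generate(**inputs, max_new_tokens=64)[0], skip_special_tokens=True))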