from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import torch
import logging
import multiprocessing

# Configure logging (standard-library logging; transformers' own `logging`
# module is intentionally not imported so the two names do not clash)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def preprocess_function(examples, tokenizer):
    # Tokenize the raw text; padding is left to the data collator, which pads
    # each training batch dynamically
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,  # reduced from 1024 to 512 to lower memory usage
    )

def main():
    try:
        # Load dataset
        logger.info("Loading dataset...")
        dataset = load_dataset("OpenAssistant/oasst1")

        # Use a smaller subset for faster training
        logger.info("Selecting smaller dataset subset...")
        dataset["train"] = dataset["train"].select(range(2000))  # reduced to 2k examples

        # Model and tokenizer setup
        logger.info("Setting up model and tokenizer...")
        model_name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

        # Preprocess dataset
        logger.info("Preprocessing dataset...")
        tokenized_dataset = dataset.map(
            lambda x: preprocess_function(x, tokenizer),
            batched=True,
            remove_columns=dataset["train"].column_names,
            num_proc=4,  # parallel processing for faster preprocessing
        )

        # Split dataset into train and eval
        logger.info("Splitting dataset into train and eval sets...")
        split_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1, seed=42)
        train_dataset = split_dataset["train"]
        eval_dataset = split_dataset["test"]
        # Configure 4-bit quantization with memory optimizations
        logger.info("Configuring 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage=torch.float16,
        )

        # Load model with quantization and memory optimizations
        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

        # Enable gradient checkpointing for memory efficiency; the KV cache
        # must be disabled while checkpointing is active
        model.config.use_cache = False
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()

        # Prepare model for k-bit training
        logger.info("Preparing model for k-bit training...")
        model = prepare_model_for_kbit_training(model)
        # LoRA configuration with reduced rank and alpha to save memory
        logger.info("Configuring LoRA...")
        lora_config = LoraConfig(
            r=8,  # reduced from 16 to 8
            lora_alpha=16,  # reduced from 32 to 16
            # phi-2's attention output projection is named "dense", not "o_proj"
            target_modules=["q_proj", "k_proj", "v_proj", "dense"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Get PEFT model
        logger.info("Getting PEFT model...")
        model = get_peft_model(model, lora_config)
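        # Report how many parameters the LoRA adapter actually trains
        # (print_trainable_parameters is provided by the PEFT-wrapped model)
        model.print_trainable_parameters()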
        # Create data collator (causal LM, so no masked-language-modeling objective)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )

        # Training arguments with memory-optimized settings
        logger.info("Setting up training arguments...")
        training_args = TrainingArguments(
            output_dir="./phi2-qlora",
            num_train_epochs=2,
            per_device_train_batch_size=4,  # reduced from 16 to 4
            per_device_eval_batch_size=4,
            gradient_accumulation_steps=4,  # increased from 1 to 4
            learning_rate=2e-4,
            fp16=True,
            logging_steps=5,
            save_strategy="epoch",
            evaluation_strategy="epoch",
            # Additional optimizations
            dataloader_num_workers=2,  # reduced from 4 to 2
            dataloader_pin_memory=True,
            warmup_ratio=0.05,
            lr_scheduler_type="cosine",
            optim="adamw_torch",
            max_grad_norm=1.0,
            group_by_length=True,
        )
        # Create trainer
        logger.info("Creating trainer...")
        trainer = SFTTrainer(
            model=model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            args=training_args,
            data_collator=data_collator,
        )

        # Train the model
        logger.info("Starting training...")
        trainer.train()

        # Save the model (LoRA adapter weights)
        logger.info("Saving model...")
        trainer.save_model("./phi2-qlora-final")

        # Save tokenizer
        logger.info("Saving tokenizer...")
        tokenizer.save_pretrained("./phi2-qlora-final")

        logger.info("Training completed successfully!")

    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise


if __name__ == "__main__":
    # Use "spawn" so dataset workers do not fork a CUDA-initialised parent process
    multiprocessing.set_start_method("spawn")
    main()
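
# A minimal sketch of how the saved adapter could be loaded for inference
# afterwards (assumes the output directory above and a peft version that
# provides AutoPeftModelForCausalLM):
#
#     from peft import AutoPeftModelForCausalLM
#     from transformers import AutoTokenizer
#     import torch
#
#     model = AutoPeftModelForCausalLM.from_pretrained(
#         "./phi2-qlora-final", device_map="auto", torch_dtype=torch.float16
#     )
#     tokenizer = AutoTokenizer.from_pretrained("./phi2-qlora-final")
#     inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
#     print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))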