# Charm_15/train.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
# Model and tokenizer setup
MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
OUTPUT_DIR = "./mixtral_finetuned"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Fallback if undefined
    tokenizer.pad_token_id = tokenizer.eos_token_id
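# Reusing EOS as the pad token means padding positions are distinguished by the
# attention mask rather than by a dedicated id; labels for those positions are
# masked out in tokenize_function below.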
# Load model with optimizations
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True
)
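# Note: device_map="auto" shards the model across all visible GPUs (and CPU memory if
# needed). Full fine-tuning of Mixtral-8x7B in bf16 needs on the order of hundreds of GB
# of accelerator memory in total; parameter-efficient methods (e.g. LoRA via the peft
# library) are a common alternative on smaller hardware.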
# Load dataset (local text files)
try:
    dataset = load_dataset("text", data_files={"train": "train.txt", "validation": "val.txt"})
except FileNotFoundError:
    print("Error: train.txt or val.txt not found. Please provide valid files.")
    exit(1)
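# The "text" loading script turns each line of train.txt / val.txt into one example
# with a single "text" field.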
# Tokenize dataset
def tokenize_function(examples):
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=512,  # Adjust to 2048 or 4096 if needed
    )
    # Causal LM objective: labels mirror input_ids, with padding positions set to -100
    # so they are ignored by the loss.
    tokenized["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attention)]
        for ids, attention in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    return tokenized
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)
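# remove_columns drops the raw "text" field so only model inputs remain
# (input_ids, attention_mask, labels).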
# Split dataset with validation check
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"] if "validation" in tokenized_datasets else None
if len(train_dataset) == 0 or (eval_dataset is not None and len(eval_dataset) == 0):
    print("Error: Empty training or validation dataset.")
    exit(1)
# Define training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    evaluation_strategy="epoch" if eval_dataset else "no",  # Skip eval if no validation set
    per_device_train_batch_size=1,  # Lowered for smaller GPUs; adjust up if possible
    per_device_eval_batch_size=1,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    gradient_accumulation_steps=8,  # Effective batch size = 1 x 8 = 8 per device
    bf16=True,
    fp16=False,
    save_strategy="epoch",
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=bool(eval_dataset),  # Only if eval exists
    metric_for_best_model="loss",
    report_to="none"
)
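# load_best_model_at_end requires matching evaluation and save strategies (both "epoch"
# here); the overall effective batch size is per_device_train_batch_size x
# gradient_accumulation_steps x the number of devices.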
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
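# No data_collator is passed: the default collator simply stacks the already
# fixed-length sequences. To pad dynamically instead, pass
# DataCollatorForLanguageModeling(tokenizer, mlm=False) and drop padding="max_length"
# in tokenize_function.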
# Train the model
try:
    trainer.train()
except RuntimeError as e:
    print(f"Training failed: {e} (likely OOM; reduce batch size or max_length)")
    exit(1)
# Save locally
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
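# Optional smoke test (assumes a CUDA device with enough free memory); uncomment to
# generate a short completion with the fine-tuned weights before cleanup:
# prompt = "Hello, how are you?"
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# output_ids = model.generate(**inputs, max_new_tokens=50)
# print(tokenizer.decode(output_ids[0], skip_special_tokens=True))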
# Clean up
del model
torch.cuda.empty_cache()
print(f"Model and tokenizer saved to {OUTPUT_DIR}")