# Charm_15/finetune.py
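"""Fine-tune the Charm 15 model with Hugging Face Transformers.

Example invocation (paths here are illustrative, not from the original script):

    python finetune.py --dataset data/train.json --eval_dataset data/val.json \
        --epochs 3 --batch_size 1 --output_dir ./finetuned_charm15
"""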
import argparse
import os
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset


def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune Charm 15 AI Model")
    parser.add_argument("--model_name", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1",
                        help="Base model name or local path (default: Mixtral-8x7B)")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to training dataset (JSON or text file)")
    parser.add_argument("--eval_dataset", type=str, default=None,
                        help="Path to optional validation dataset")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Per-device training batch size (lowered for GPU compatibility)")
    parser.add_argument("--lr", type=float, default=5e-5,
                        help="Learning rate")
    parser.add_argument("--output_dir", type=str, default="./finetuned_charm15",
                        help="Model save directory")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Max token length for training")
    return parser.parse_args()


def tokenize_function(examples, tokenizer, max_length):
    """Tokenize dataset and prepare labels for causal LM."""
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # return_tensors="pt" is omitted: inside datasets.map(batched=True) the
    # tokenizer should return plain lists, which Trainer converts to tensors.
    # For causal LM fine-tuning the labels are the input ids themselves.
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


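# Assumed record format (not spelled out in the original script): every example
# must expose a "text" field. For a JSON dataset that means lines such as
#   {"text": "<s>[INST] What is Charm 15? [/INST] A fine-tuned assistant.</s>"}
# while a plain text file is read one example per line by the "text" loader.

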
def main():
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Ensure output and logging directories exist
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs("./logs", exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer from {args.model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        if tokenizer.pad_token is None:
            # Mistral/Mixtral tokenizers ship without a pad token; reuse EOS
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        sys.exit(1)

    # Load model with memory optimizations
    print(f"Loading model {args.model_name}...")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,   # efficient precision on recent GPUs
            device_map="auto",            # let accelerate spread layers across GPU/CPU
            low_cpu_mem_usage=True,       # reduce peak RAM while loading
        )
        # Note: no .to(device) here; moving a model dispatched with
        # device_map="auto" conflicts with accelerate's device placement.
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

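    # Full fine-tuning of Mixtral-8x7B is extremely memory-hungry. A common
    # lighter-weight alternative (not part of this script) is a LoRA adapter
    # via the peft library; a minimal sketch, assuming peft is installed:
    #
    #   from peft import LoraConfig, get_peft_model
    #   lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
    #                         target_modules=["q_proj", "v_proj"],
    #                         task_type="CAUSAL_LM")
    #   model = get_peft_model(model, lora_cfg)
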
    # Load dataset (JSON or plain text, selected by file extension)
    print(f"Loading dataset from {args.dataset}...")
    try:
        if args.dataset.endswith(".json"):
            dataset = load_dataset("json", data_files={"train": args.dataset})
        else:
            dataset = load_dataset("text", data_files={"train": args.dataset})
        eval_dataset = None
        if args.eval_dataset:
            if args.eval_dataset.endswith(".json"):
                eval_dataset = load_dataset("json", data_files={"train": args.eval_dataset})["train"]
            else:
                eval_dataset = load_dataset("text", data_files={"train": args.eval_dataset})["train"]
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

    # Tokenize datasets (assumes a "text" column, as produced by both loaders)
    print("Tokenizing dataset...")
    train_dataset = dataset["train"].map(
        lambda x: tokenize_function(x, tokenizer, args.max_length),
        batched=True,
        remove_columns=["text"],
    )
    if eval_dataset is not None:
        eval_dataset = eval_dataset.map(
            lambda x: tokenize_function(x, tokenizer, args.max_length),
            batched=True,
            remove_columns=["text"],
        )

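    # Because every example is padded to max_length above, no data collator is
    # needed. Dynamic padding would be an alternative (not used here), e.g.
    # transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # passed to the Trainer via data_collator=...
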
    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        gradient_accumulation_steps=8,  # effective batch = batch_size * 8 (8 with the default batch_size of 1)
        bf16=True,                      # match the bfloat16 load dtype
        fp16=False,
        save_total_limit=2,
        logging_dir="./logs",
        logging_steps=100,
        report_to="none",
        # renamed to eval_strategy in newer transformers releases
        evaluation_strategy="epoch" if eval_dataset else "no",
        save_strategy="epoch",          # save_steps dropped: it is ignored with epoch saving
        load_best_model_at_end=bool(eval_dataset),
        metric_for_best_model="loss",
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    # Train
    print("Starting fine-tuning...")
    try:
        trainer.train()
    except RuntimeError as e:
        print(f"Training failed: {e} (try reducing batch_size or max_length)")
        sys.exit(1)

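    # An interrupted run can be resumed from the newest checkpoint in
    # output_dir with trainer.train(resume_from_checkpoint=True), which the
    # transformers Trainer supports out of the box.
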
    # Save the fine-tuned model and tokenizer
    print(f"Saving fine-tuned model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Cleanup
    del model
    torch.cuda.empty_cache()
    print("Training complete. Memory cleared.")


if __name__ == "__main__":
    main()