import argparse
import os
import sys

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments


def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune Charm 15 AI Model")
    parser.add_argument("--model_name", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1",
                        help="Base model name or local path (default: Mixtral-8x7B)")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to training dataset (JSON or text file)")
    parser.add_argument("--eval_dataset", type=str, default=None,
                        help="Path to optional validation dataset")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Per-device training batch size (lowered for GPU compatibility)")
    parser.add_argument("--lr", type=float, default=5e-5,
                        help="Learning rate")
    parser.add_argument("--output_dir", type=str, default="./finetuned_charm15",
                        help="Model save directory")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Max token length for training")
    return parser.parse_args()


def tokenize_function(examples, tokenizer, max_length):
    """Tokenize dataset and prepare labels for causal LM."""
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    # For causal LM training the labels are a copy of the input IDs; the model
    # shifts them by one position internally when computing the loss. Padding
    # positions are kept in the labels here; replacing them with -100 would
    # exclude pad tokens from the loss.
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized


def main():
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs("./logs", exist_ok=True)

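    # Load the tokenizer; many causal-LM tokenizers (Mixtral's included) define
    # no pad token, so fall back to the EOS token for padding.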
    print(f"Loading tokenizer from {args.model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        sys.exit(1)

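    # Load the model in bfloat16; device_map="auto" lets accelerate place the
    # weights across the available GPUs (and CPU RAM if needed).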
    print(f"Loading model {args.model_name}...")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        # No .to(device) here: device_map="auto" already dispatches the model,
        # and moving a dispatched model raises an error.
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

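    # Load the training (and optional validation) data. JSON files are expected
    # to contain a "text" field; plain-text files yield one example per line.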
    print(f"Loading dataset from {args.dataset}...")
    try:
        if args.dataset.endswith(".json"):
            dataset = load_dataset("json", data_files={"train": args.dataset})
        else:
            dataset = load_dataset("text", data_files={"train": args.dataset})

        eval_dataset = None
        if args.eval_dataset:
            if args.eval_dataset.endswith(".json"):
                eval_dataset = load_dataset("json", data_files={"train": args.eval_dataset})["train"]
            else:
                eval_dataset = load_dataset("text", data_files={"train": args.eval_dataset})["train"]
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

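    # Tokenize both splits; drop the raw "text" column so the Trainer only
    # sees input_ids, attention_mask, and labels.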
    print("Tokenizing dataset...")
    train_dataset = dataset["train"].map(
        lambda x: tokenize_function(x, tokenizer, args.max_length),
        batched=True,
        remove_columns=["text"]
    )
    eval_dataset = eval_dataset.map(
        lambda x: tokenize_function(x, tokenizer, args.max_length),
        batched=True,
        remove_columns=["text"]
    ) if args.eval_dataset else None

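    # Effective batch size is batch_size * gradient_accumulation_steps (8 here).
    # bf16 training assumes hardware with bfloat16 support (e.g. Ampere+ GPUs),
    # and save_steps is ignored because save_strategy="epoch" controls checkpoints.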
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        gradient_accumulation_steps=8,
        bf16=True,
        fp16=False,
        save_total_limit=2,
        save_steps=500,
        logging_dir="./logs",
        logging_steps=100,
        report_to="none",
        # Note: newer transformers releases rename this argument to eval_strategy.
        evaluation_strategy="epoch" if eval_dataset else "no",
        save_strategy="epoch",
        load_best_model_at_end=bool(eval_dataset),
        metric_for_best_model="loss"
    )

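    # Passing the tokenizer lets the Trainer save it with each checkpoint; the
    # data is already padded to max_length, so no custom collator is required.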
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer
    )

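    # CUDA out-of-memory failures surface as RuntimeError, which is why only
    # that exception type is caught here.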
    print("Starting fine-tuning...")
    try:
        trainer.train()
    except RuntimeError as e:
        print(f"Training failed: {e} (Try reducing batch_size or max_length)")
        sys.exit(1)

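    # Write the final weights, config, and tokenizer to output_dir so the
    # directory can be reloaded directly with from_pretrained().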
    print(f"Saving fine-tuned model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

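    # Drop the model reference and release cached GPU memory before exiting.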
    del model
    torch.cuda.empty_cache()
    print("Training complete. Memory cleared.")


if __name__ == "__main__":
    main()