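"""Fine-tune the Charm 15 AI model with Hugging Face Transformers.

Example invocation (the script name and dataset paths below are illustrative,
not part of the repository):

    python finetune_charm15.py \
        --dataset data/train.json \
        --eval_dataset data/val.json \
        --epochs 3 --batch_size 1 --max_length 512
"""
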
import argparse
import os
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset

def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune Charm 15 AI Model")
    parser.add_argument("--model_name", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1", 
                        help="Base model name or local path (default: Mixtral-8x7B)")
    parser.add_argument("--dataset", type=str, required=True, 
                        help="Path to training dataset (JSON or text file)")
    parser.add_argument("--eval_dataset", type=str, default=None, 
                        help="Path to optional validation dataset")
    parser.add_argument("--epochs", type=int, default=3, 
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1, 
                        help="Per-device training batch size (lowered for GPU compatibility)")
    parser.add_argument("--lr", type=float, default=5e-5, 
                        help="Learning rate")
    parser.add_argument("--output_dir", type=str, default="./finetuned_charm15", 
                        help="Model save directory")
    parser.add_argument("--max_length", type=int, default=512, 
                        help="Max token length for training")
    return parser.parse_args()

def tokenize_function(examples, tokenizer, max_length):
    """Tokenize a batch of examples and prepare labels for causal LM."""
    tokenized = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )
    # For causal LM fine-tuning the labels are the input ids themselves;
    # the model shifts them internally when computing the loss.
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized
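
# Expected dataset record shape (an assumption, inferred from the "text"
# column consumed above):
#   JSON: one object per line, e.g. {"text": "An example training sentence."}
#   text: one raw training example per line.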

def main():
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs("./logs", exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer from {args.model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
        if tokenizer.pad_token is None:
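            # Many causal LM tokenizers (e.g., Mistral/Mixtral) ship without a
            # pad token; reusing EOS enables padded batching.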
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        sys.exit(1)

    # Load model with optimizations
    print(f"Loading model {args.model_name}...")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,    # Efficient precision (needs bf16-capable hardware)
            device_map="auto",             # Let accelerate place layers across GPU/CPU
            low_cpu_mem_usage=True         # Reduce peak RAM while loading
        )
        # No .to(device) here: device_map="auto" already dispatches the model,
        # and moving a dispatched model raises an error.
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    # Load dataset
    print(f"Loading dataset from {args.dataset}...")
    try:
        if args.dataset.endswith(".json"):
            dataset = load_dataset("json", data_files={"train": args.dataset})
        else:
            dataset = load_dataset("text", data_files={"train": args.dataset})
        
        eval_dataset = None
        if args.eval_dataset:
            if args.eval_dataset.endswith(".json"):
                eval_dataset = load_dataset("json", data_files={"train": args.eval_dataset})["train"]
            else:
                eval_dataset = load_dataset("text", data_files={"train": args.eval_dataset})["train"]
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

    # Tokenize datasets
    print("Tokenizing dataset...")
    train_dataset = dataset["train"].map(
        lambda x: tokenize_function(x, tokenizer, args.max_length),
        batched=True,
        remove_columns=["text"]
    )
    if eval_dataset is not None:
        eval_dataset = eval_dataset.map(
            lambda x: tokenize_function(x, tokenizer, args.max_length),
            batched=True,
            remove_columns=["text"]
        )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        gradient_accumulation_steps=8,     # Effective batch size = batch_size * 8
        bf16=True,                         # Match model dtype (needs bf16-capable hardware)
        fp16=False,
        save_total_limit=2,                # Keep only the two most recent checkpoints
        logging_dir="./logs",
        logging_steps=100,
        report_to="none",
        evaluation_strategy="epoch" if eval_dataset else "no",
        save_strategy="epoch",             # save_steps would be ignored with epoch-based saving
        load_best_model_at_end=bool(eval_dataset),
        metric_for_best_model="loss"
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer
    )

    # Train
    print("Starting fine-tuning...")
    try:
        trainer.train()
    except RuntimeError as e:
        print(f"Training failed: {e} (Try reducing batch_size or max_length)")
        sys.exit(1)

    # Save
    print(f"Saving fine-tuned model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
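
    # The saved directory can be reloaded later with
    # AutoModelForCausalLM.from_pretrained(args.output_dir) and
    # AutoTokenizer.from_pretrained(args.output_dir).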

    # Cleanup
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Training complete. Memory cleared.")

if __name__ == "__main__":
    main()