GeminiFan207 committed on
Commit 21102ed · verified · 1 Parent(s): 8362c4c

Create finetune.py

Files changed (1):
finetune.py +152 -0
finetune.py ADDED
@@ -0,0 +1,152 @@
+ import torch
+ import argparse
+ import os
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
+ from datasets import load_dataset
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Fine-tune Charm 15 AI Model")
+     parser.add_argument("--model_name", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1",
+                         help="Base model name or local path (default: Mixtral-8x7B)")
+     parser.add_argument("--dataset", type=str, required=True,
+                         help="Path to training dataset (JSON or text file)")
+     parser.add_argument("--eval_dataset", type=str, default=None,
+                         help="Path to optional validation dataset")
+     parser.add_argument("--epochs", type=int, default=3,
+                         help="Number of training epochs")
+     parser.add_argument("--batch_size", type=int, default=1,
+                         help="Per-device training batch size (lowered for GPU compatibility)")
+     parser.add_argument("--lr", type=float, default=5e-5,
+                         help="Learning rate")
+     parser.add_argument("--output_dir", type=str, default="./finetuned_charm15",
+                         help="Model save directory")
+     parser.add_argument("--max_length", type=int, default=512,
+                         help="Max token length for training")
+     return parser.parse_args()
+
+ def tokenize_function(examples, tokenizer, max_length):
+     """Tokenize dataset and prepare labels for causal LM."""
+     tokenized = tokenizer(
+         examples["text"],
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     )
+     # For causal LM training the labels are simply a copy of the input ids.
+     tokenized["labels"] = tokenized["input_ids"].copy()
+     return tokenized
+
+ def main():
+     args = parse_args()
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     # Ensure output directories exist
+     os.makedirs(args.output_dir, exist_ok=True)
+     os.makedirs("./logs", exist_ok=True)
+
+     # Load tokenizer
+     print(f"Loading tokenizer from {args.model_name}...")
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+             tokenizer.pad_token_id = tokenizer.eos_token_id
+     except Exception as e:
+         print(f"Error loading tokenizer: {e}")
+         exit(1)
+
+     # Load model with memory-friendly settings
+     print(f"Loading model {args.model_name}...")
+     try:
+         model = AutoModelForCausalLM.from_pretrained(
+             args.model_name,
+             torch_dtype=torch.bfloat16,  # Efficient precision
+             device_map="auto",           # Spread across GPU/CPU
+             low_cpu_mem_usage=True       # Reduce RAM usage while loading
+         )
+         # Note: no .to(device) here; device_map="auto" already places the weights.
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         exit(1)
+
+     # Load dataset
+     print(f"Loading dataset from {args.dataset}...")
+     try:
+         if args.dataset.endswith(".json"):
+             dataset = load_dataset("json", data_files={"train": args.dataset})
+         else:
+             dataset = load_dataset("text", data_files={"train": args.dataset})
+
+         eval_dataset = None
+         if args.eval_dataset:
+             if args.eval_dataset.endswith(".json"):
+                 eval_dataset = load_dataset("json", data_files={"train": args.eval_dataset})["train"]
+             else:
+                 eval_dataset = load_dataset("text", data_files={"train": args.eval_dataset})["train"]
+     except Exception as e:
+         print(f"Error loading dataset: {e}")
+         exit(1)
+
+     # Tokenize datasets
+     print("Tokenizing dataset...")
+     train_dataset = dataset["train"].map(
+         lambda x: tokenize_function(x, tokenizer, args.max_length),
+         batched=True,
+         remove_columns=["text"]
+     )
+     eval_dataset = eval_dataset.map(
+         lambda x: tokenize_function(x, tokenizer, args.max_length),
+         batched=True,
+         remove_columns=["text"]
+     ) if args.eval_dataset else None
+
+     # Training arguments
+     training_args = TrainingArguments(
+         output_dir=args.output_dir,
+         per_device_train_batch_size=args.batch_size,
+         per_device_eval_batch_size=args.batch_size,
+         num_train_epochs=args.epochs,
+         learning_rate=args.lr,
+         gradient_accumulation_steps=8,  # Effective batch size = batch_size * 8
+         bf16=True,                      # Match model dtype
+         fp16=False,
+         save_total_limit=2,
+         logging_dir="./logs",
+         logging_steps=100,
+         report_to="none",
+         evaluation_strategy="epoch" if eval_dataset else "no",
+         save_strategy="epoch",
+         load_best_model_at_end=bool(eval_dataset),
+         metric_for_best_model="loss"
+     )
+
+     # Initialize Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         tokenizer=tokenizer
+     )
+
+     # Train
+     print("Starting fine-tuning...")
+     try:
+         trainer.train()
+     except RuntimeError as e:
+         print(f"Training failed: {e} (Try reducing batch_size or max_length)")
+         exit(1)
+
+     # Save
+     print(f"Saving fine-tuned model to {args.output_dir}")
+     trainer.save_model(args.output_dir)
+     tokenizer.save_pretrained(args.output_dir)
+
+     # Cleanup
+     del model
+     torch.cuda.empty_cache()
+     print("Training complete. Memory cleared.")
+
+ if __name__ == "__main__":
+     main()
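
Usage note (not part of the commit): the script is invoked from the command line, e.g. python finetune.py --dataset data/train.json --epochs 3, where the dataset path is only a placeholder. Below is a minimal sketch of how the saved checkpoint could be loaded back for a quick generation test, assuming the default --output_dir of ./finetuned_charm15; the prompt string is purely illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "./finetuned_charm15"  # default --output_dir used by finetune.py (assumption)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,  # match the training dtype
    device_map="auto"            # let accelerate place the weights
)

prompt = "Hello, Charm 15."  # placeholder prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))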