burtenshaw (HF Staff) committed
Commit 6c691d0 · verified · 1 Parent(s): 88c73d2

Upload train.py with huggingface_hub

Files changed (1)
train.py +211 -0
train.py ADDED
@@ -0,0 +1,211 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets",
#     "httpx",
#     "huggingface-hub",
#     "setuptools",
#     "transformers",
#     "torch",
#     "accelerate",
#     "trl",
#     "peft",
#     "wandb",
#     "bitsandbytes",
#     "torchvision",
#     "torchaudio",
# ]
#
# ///

"""## Import libraries"""

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer, setup_chat_format
from peft import LoraConfig

"""## Load Dataset"""

dataset_name = "allenai/tulu-3-sft-personas-code"  # Example dataset

# Load dataset
dataset = load_dataset(dataset_name, split="train")
print(f"Dataset loaded: {dataset}")

# Let's look at a sample
print("\nSample data:")
print(dataset[0])

dataset = dataset.remove_columns("prompt")
dataset = dataset.train_test_split(test_size=0.2)

print(
    f"Train Samples: {len(dataset['train'])}\nTest Samples: {len(dataset['test'])}"
)

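# Optional sanity check (not in the original script): the SFT trainer expects
# conversational data as a "messages" column of {"role", "content"} dicts,
# which this dataset provides. Inspect one conversation to confirm the shape.
for turn in dataset["train"][0]["messages"]:
    print(f"{turn['role']}: {turn['content'][:80]}...")
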
"""## Configuration

Set up the configuration parameters for the fine-tuning process.
"""

# Model configuration
model_name = "Qwen/Qwen3-30B-A3B"  # You can change this to any model you want to fine-tune

# Other compatible Qwen3 models:
# model_name = "Qwen/Qwen3-32B"
# model_name = "Qwen/Qwen3-14B"
# model_name = "Qwen/Qwen3-8B"
# model_name = "Qwen/Qwen3-4B"
# model_name = "Qwen/Qwen3-1.7B"
# model_name = "Qwen/Qwen3-0.6B"

# Training configuration
use_peft = True  # Train LoRA adapters (PEFT) instead of doing a full fine-tune
output_dir = "./output/sft-model"
num_train_epochs = 1
per_device_train_batch_size = 1
gradient_accumulation_steps = 1
learning_rate = 2e-4 if use_peft else 2e-5  # Higher learning rate for PEFT

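# Optional (not in the original script): fix the random seed so data shuffling
# and LoRA initialization are reproducible across runs.
from transformers import set_seed

set_seed(42)
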
"""## Load model and tokenizer"""

# Specify how to quantize the model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

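# Note (optional tweak, not in the original config): bnb_4bit_compute_dtype
# defaults to float32; setting it to bfloat16 usually speeds up 4-bit training
# on recent GPUs.
# quantization_config.bnb_4bit_compute_dtype = torch.bfloat16
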
# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    use_cache=False,  # Disable KV cache during training
    device_map="auto",
    quantization_config=quantization_config,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# # Set up chat formatting (if the model doesn't have a chat template)
# if tokenizer.chat_template is None:
#     model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")

# # Set padding token
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token

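# Optional check (not in the original script): the steps below rely on the
# tokenizer shipping both a chat template and an EOS token, which Qwen3
# tokenizers do; fail early if either is missing.
assert tokenizer.chat_template is not None, "Tokenizer has no chat template"
assert tokenizer.eos_token is not None, "Tokenizer has no EOS token"
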
"""## Configure PEFT (if enabled)"""

# Set up PEFT configuration if enabled
peft_config = LoraConfig(
    r=32,  # Rank
    lora_alpha=16,  # Alpha parameter for LoRA scaling
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)

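# Note (added for context): the LoRA update is scaled by lora_alpha / r, so
# with lora_alpha=16 and r=32 the adapters contribute with a factor of 0.5.
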
"""## Configure SFT Trainer"""

# Training arguments
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    gradient_checkpointing=True,
    logging_steps=25,
    save_strategy="epoch",
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    max_length=1024,
    packing=True,  # Enable packing to increase training efficiency
    eos_token=tokenizer.eos_token,
    bf16=True,
    fp16=False,
    max_steps=1000,
    report_to="wandb",  # Log metrics to Weights & Biases (set to "none" to disable reporting)
)

"""## Initialize and run the SFT Trainer"""

# Create SFT Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"] if "test" in dataset else None,
    peft_config=peft_config if use_peft else None,
    processing_class=tokenizer,
)

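# Optional (not in the original script): when training with LoRA, the trainer
# wraps the model in a PeftModel, which can report how few parameters are
# actually being updated.
if use_peft:
    trainer.model.print_trainable_parameters()
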
# Train the model
trainer.train()

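# Optional (not in the original script): run a final evaluation pass on the
# held-out split and print the loss-based metrics.
if trainer.eval_dataset is not None:
    eval_metrics = trainer.evaluate()
    print(f"Evaluation metrics: {eval_metrics}")
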
"""## Save the fine-tuned model"""

# Save the model
trainer.save_model(output_dir)

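# Also save the tokenizer so the output directory can be loaded on its own
# (optional; not in the original script).
tokenizer.save_pretrained(output_dir)
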
"""## Test the fine-tuned model"""

from peft import PeftModel

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
)

# Load the fine-tuned PEFT model
model = PeftModel.from_pretrained(base_model, output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Test the model with an example
prompt = """Write a function called is_palindrome that takes a single string as input and returns True if the string is a palindrome, and False otherwise.

Palindrome Definition:

A palindrome is a word, phrase, number, or other sequence of characters that reads the same forward and backward, ignoring spaces, punctuation, and capitalization.

Example:
```
is_palindrome("racecar") # Returns True
is_palindrome("hello") # Returns False
is_palindrome("A man, a plan, a canal: Panama") # Returns True
```
"""

# Format the chat prompt using the tokenizer's chat template
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(f"Formatted prompt: {formatted_prompt}")

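# Note (optional, depends on the chat template): Qwen3 templates also accept an
# `enable_thinking` flag, e.g.
#   tokenizer.apply_chat_template(..., enable_thinking=False)
# if you want generations without the model's reasoning trace.
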
# Generate response
model.eval()
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=500,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\nGenerated Response:")
print(response)

model.push_to_hub("burtenshaw/Qwen3-30B-A3B-python-code")
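# Optional follow-ups (not in the original script): push the tokenizer to the
# same repo, and/or merge the LoRA weights into the base model for standalone
# use. Merging a 30B-parameter model needs a large amount of memory, and the
# "-merged" repo name below is just an illustrative placeholder.
tokenizer.push_to_hub("burtenshaw/Qwen3-30B-A3B-python-code")
# merged_model = model.merge_and_unload()
# merged_model.push_to_hub("burtenshaw/Qwen3-30B-A3B-python-code-merged")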