# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets",
# "httpx",
# "huggingface - hub",
# "setuptools",
# "transformers",
# "torch",
# "accelerate",
# "trl",
# "peft",
# "wandb",
# "torchvision",
# "torchaudio"
# ]
# ///
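# The inline metadata above lets a PEP 723-aware runner resolve the
# dependencies automatically; for example, assuming uv is installed:
#   uv run train.py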
"""## Import libraries"""
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer, setup_chat_format
from peft import LoraConfig
"""# Load Dataset"""
dataset_name = "allenai/tulu-3-sft-personas-code" # Example dataset
# Load dataset
dataset = load_dataset(dataset_name, split="train")
print(f"Dataset loaded: {dataset}")
# Let's look at a sample
print("\nSample data:")
print(dataset[0])
dataset = dataset.remove_columns("prompt") # Drop the standalone prompt column; the chat turns live in "messages"
dataset = dataset.train_test_split(test_size=0.2)
print(
f"Train Samples: {len(dataset['train'])}\nTest Samples: {len(dataset['test'])}"
)
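# SFTTrainer's conversational path expects rows shaped like
# {"messages": [{"role": "...", "content": "..."}, ...]}; a quick sanity check:
# assert "messages" in dataset["train"].column_names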
"""## Configuration
Set up the configuration parameters for the fine-tuning process.
"""
# Model configuration
model_name = "Qwen/Qwen3-30B-A3B" # You can change this to any model you want to fine-tune
# # Other compatible Qwen3 models
# model_name = "Qwen/Qwen3-32B"
# model_name = "Qwen/Qwen3-14B"
# model_name = "Qwen/Qwen3-8B"
# model_name = "Qwen/Qwen3-4B"
# model_name = "Qwen/Qwen3-1.7B"
# model_name = "Qwen/Qwen3-0.6B"
# Training configuration
output_dir = "./tmp/sft-model"
num_train_epochs = 1
per_device_train_batch_size = 1
gradient_accumulation_steps = 1
learning_rate = 2e-4
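# Effective batch size per model replica is per_device_train_batch_size *
# gradient_accumulation_steps (1 * 1 = 1 here); increase
# gradient_accumulation_steps to simulate larger batches without extra memory.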
"""## Load model and tokenizer"""
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
use_cache=False, # KV cache is unused during training and conflicts with gradient checkpointing
device_map="auto",
)
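# Rough memory estimate (not a measurement): a 30B-parameter model in bfloat16
# needs ~60 GB for the weights alone (30e9 params * 2 bytes), so
# device_map="auto" will shard it across however many GPUs are available.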
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# # Set up chat formatting (if the model doesn't have a chat template)
# if tokenizer.chat_template is None:
# model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
# # Set padding token
# if tokenizer.pad_token is None:
# tokenizer.pad_token = tokenizer.eos_token
"""## Configure PEFT (if enabled)"""
# Set up PEFT configuration if enabled
peft_config = LoraConfig(
r=32, # Rank
lora_alpha=16, # LoRA scaling numerator (effective scale = lora_alpha / r = 0.5 here)
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear",
)
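# SFTTrainer applies peft_config itself, but if you want to see the trainable
# parameter count up front, you could wrap the model manually (a quick check,
# not required for training):
# from peft import get_peft_model
# peft_model = get_peft_model(model, peft_config)
# peft_model.print_trainable_parameters() # e.g. "trainable params: ... || all params: ..."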
"""## Configure SFT Trainer"""
# Training arguments
training_args = SFTConfig(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
gradient_checkpointing=True,
logging_steps=25,
save_strategy="epoch",
optim="adamw_torch",
lr_scheduler_type="cosine",
warmup_ratio=0.1,
max_length=1024,
packing=True, # Enable packing to increase training efficiency
eos_token=tokenizer.eos_token,
bf16=True,
fp16=False,
max_steps=1000, # Overrides num_train_epochs when set
report_to="wandb", # Log metrics to Weights & Biases (requires a wandb login)
)
"""## Initialize and run the SFT Trainer"""
# Create SFT Trainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"] if "test" in dataset else None,
peft_config=peft_config,
processing_class=tokenizer,
)
# Train the model
trainer.train()
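# No eval_strategy is set, so the test split is not scored during training;
# you could run a one-off evaluation afterwards (inherited from Trainer):
# metrics = trainer.evaluate()
# print(metrics)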
"""## Save the fine-tuned model"""
# Save the model
trainer.save_model(output_dir)
"""## Test the fine-tuned model"""
from peft import PeftModel
# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
)
# Load the fine-tuned PEFT model
model = PeftModel.from_pretrained(base_model, output_dir)
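# Optionally merge the LoRA weights into the base model for faster inference
# (this returns a plain transformers model, so the adapter can no longer be
# trained separately):
# model = model.merge_and_unload()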
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Test the model with an example
prompt = """Write a function called is_palindrome that takes a single string as input and returns True if the string is a palindrome, and False otherwise.
Palindrome Definition:
A palindrome is a word, phrase, number, or other sequence of characters that reads the same forward and backward, ignoring spaces, punctuation, and capitalization.
Example:
```
is_palindrome("racecar") # Returns True
is_palindrome("hello") # Returns False
is_palindrome("A man, a plan, a canal: Panama") # Returns True
```
"""
# Format the chat prompt using the tokenizer's chat template
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
formatted_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
print(f"Formatted prompt: {formatted_prompt}")
# Generate response
model.eval()
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=500,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\nGenerated Response:")
print(response)
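# Note that the decoded text above includes the prompt; to print only the
# newly generated tokens, slice them off first:
# new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
# print(tokenizer.decode(new_tokens, skip_special_tokens=True))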
model.push_to_hub("burtenshaw/Qwen3-30B-A3B-python-code")
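# Pushing a PeftModel uploads only the adapter weights; to make the repo
# self-contained, you could push the tokenizer alongside it:
# tokenizer.push_to_hub("burtenshaw/Qwen3-30B-A3B-python-code")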