File size: 7,778 Bytes
5755412 4cf237b 13b1681 f82c314 13b1681 f82c314 eccd8f6 13b1681 4e66e3d 13b1681 4e66e3d 4cf237b 13b1681 b446d41 13b1681 f82c314 b26485f 13b1681 b26485f 4cf237b 13b1681 f82c314 13b1681 b26485f 13b1681 b26485f 13b1681 b26485f 13b1681 5755412 eccd8f6 13b1681 eccd8f6 f82c314 13b1681 4e66e3d f82c314 13b1681 4e66e3d f82c314 13b1681 f82c314 4e66e3d f82c314 13b1681 f82c314 13b1681 f82c314 13b1681 f82c314 13b1681 f82c314 4cf237b f82c314 13b1681 f82c314 13b1681 f82c314 13b1681 f82c314 13b1681 4cf237b 13b1681 f82c314 13b1681 eccd8f6 4cf237b b26485f 13b1681 b26485f f82c314 13b1681 f82c314 13b1681 4cf237b 13b1681 4cf237b 13b1681 f82c314 b446d41 4cf237b 13b1681 4e66e3d b26485f 13b1681 b26485f 4e66e3d f82c314 4e66e3d f82c314 4e66e3d f82c314 4e66e3d eccd8f6 4cf237b 4e66e3d b446d41 13b1681 b26485f 13b1681 4e66e3d f82c314 13b1681 4e66e3d 4cf237b 4e66e3d f82c314 4e66e3d f82c314 4e66e3d f82c314 eccd8f6 5755412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
import gradio as gr
import spaces
import torch
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
pipeline,
BitsAndBytesConfig, # for 4-bit config
)
# PEFT (LoRA / QLoRA)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
# ---------------------------------------------------------------------------
# ZeroGPU + QLoRA Example
# ---------------------------------------------------------------------------

# Shared text-generation pipeline. Starts empty; it is filled either by the
# finetuning routine (LoRA model) or lazily with the plain base model.
TEXT_PIPELINE = None

# How many WikiText-2 training lines the demo finetune uses.
NUM_EXAMPLES = 50
@spaces.GPU(duration=600)  # up to 10 min
def finetune_small_subset():
    """
    Finetune 'wuhp/myr1' with QLoRA on a small WikiText-2 slice and load the
    result for inference.

    Steps:
      1) Load up to NUM_EXAMPLES non-empty lines of WikiText-2 (train split).
      2) Load 'wuhp/myr1' in 4-bit NF4 quantization (QLoRA style).
      3) Attach trainable LoRA adapters to the "q_proj" / "v_proj" modules.
      4) Train for one epoch, then save the LoRA adapter and tokenizer to
         'finetuned_myr1'.
      5) Free the training objects, reload the 4-bit base model, apply the
         saved adapter, and store a text-generation pipeline in the
         module-level TEXT_PIPELINE.

    Returns:
        A status message for the Gradio status textbox.
    """
    global TEXT_PIPELINE
    import gc

    from peft import PeftModel

    # --- 1) Load WikiText-2 subset ---
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # The raw WikiText-2 split includes blank lines. They tokenize to (almost)
    # nothing, so they provide no next-token target and can break the
    # collator/loss; drop them before taking the subset.
    ds = ds.filter(lambda ex: bool(ex["text"].strip()))
    ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    # --- 2) Setup 4-bit quantization with BitsAndBytes ---
    # QLoRA: the frozen base model is loaded in 4-bit and only the LoRA
    # adapters attached below are trained.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if preferred
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",  # "nf4" is standard for QLoRA
    )
    config = AutoConfig.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True,
    )
    # DataCollatorForLanguageModeling pads every batch, and padding raises if
    # the tokenizer has no pad token (common for causal LMs). Fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model in 4-bit
    base_model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        quantization_config=bnb_config,  # <--- QLoRA 4-bit
        device_map="auto",
        trust_remote_code=True,
    )
    # PEFT helper that readies a quantized (k-bit) model for training.
    base_model = prepare_model_for_kbit_training(base_model)

    # --- 3) Create LoRA config & wrap the base model in LoRA adapter ---
    # "q_proj"/"v_proj" match LLaMA-style attention layers; other
    # architectures may need different names ("c_attn", "W_pack",
    # "query_key_value", ...).
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    lora_model = get_peft_model(base_model, lora_config)

    # --- 4) Tokenize dataset ---
    def tokenize_fn(ex):
        # Batched map: ex["text"] is a list of strings.
        return tokenizer(ex["text"], truncation=True, max_length=512)

    ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
    ds.set_format("torch")

    # mlm=False: causal language-modeling objective.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="finetuned_myr1",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
        save_steps=999999,  # effectively no intermediate checkpoints
        save_total_limit=1,
        fp16=False,  # compute dtype comes from bnb_4bit_compute_dtype above
    )
    trainer = Trainer(
        model=lora_model,
        args=training_args,
        train_dataset=ds,
        data_collator=collator,
    )

    # --- 5) Train ---
    trainer.train()

    # Save the LoRA adapter (PEFT save_pretrained) and the tokenizer.
    trainer.model.save_pretrained("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

    # --- 6) Reload for inference ---
    # Release the training objects first so two 4-bit copies of the model are
    # not held on the GPU at the same time.
    del trainer, lora_model, base_model
    gc.collect()
    torch.cuda.empty_cache()

    base_model_2 = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Inference only, so prepare_model_for_kbit_training is not applied here.
    # PeftModel.from_pretrained loads the adapter saved above onto the base
    # model and makes it the active adapter.
    lora_model_2 = PeftModel.from_pretrained(base_model_2, "finetuned_myr1")
    lora_model_2.eval()

    TEXT_PIPELINE = pipeline("text-generation", model=lora_model_2, tokenizer=tokenizer)
    return "Finetuning complete (QLoRA + LoRA). Model loaded for inference."
def ensure_pipeline():
    """
    Return the shared text-generation pipeline, building it on first use.

    If TEXT_PIPELINE is already set (for example by finetuning), it is
    returned as-is. Otherwise the plain 'wuhp/myr1' base model is loaded in
    4-bit NF4 quantization with no LoRA adapter, wrapped in a
    text-generation pipeline, stored in TEXT_PIPELINE, and returned.
    """
    global TEXT_PIPELINE
    if TEXT_PIPELINE is not None:
        return TEXT_PIPELINE

    repo_id = "wuhp/myr1"
    # Options shared by every from_pretrained call below.
    hub_kwargs = {"subfolder": "myr1", "trust_remote_code": True}

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model_cfg = AutoConfig.from_pretrained(repo_id, **hub_kwargs)
    tok = AutoTokenizer.from_pretrained(repo_id, **hub_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        config=model_cfg,
        quantization_config=quant_cfg,
        device_map="auto",
        **hub_kwargs,
    )

    TEXT_PIPELINE = pipeline("text-generation", model=model, tokenizer=tok)
    return TEXT_PIPELINE
@spaces.GPU(duration=120)  # up to 2 min for text generation
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Generate text for *prompt*. Uses the finetuned (LoRA) model if one has
    been loaded, otherwise the 4-bit base model (via ensure_pipeline).

    Args:
        prompt: Input text.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding, because transformers rejects a non-positive
            temperature when sampling is enabled.
        top_p: Nucleus-sampling threshold (unused for greedy decoding).
        min_new_tokens: Minimum number of new tokens; clamped so it never
            exceeds max_new_tokens.
        max_new_tokens: Maximum number of new tokens.

    Returns:
        The "generated_text" field of the pipeline's first result.
    """
    pipe = ensure_pipeline()

    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
    # The min/max sliders are independent, so min may arrive larger than max.
    min_new_tokens = min(int(min_new_tokens), max_new_tokens)

    gen_kwargs = {
        "min_new_tokens": min_new_tokens,
        "max_new_tokens": max_new_tokens,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # temperature=0 with do_sample=True raises ValueError in transformers;
        # treat it as a request for deterministic greedy decoding.
        gen_kwargs["do_sample"] = False

    out = pipe(prompt, **gen_kwargs)
    return out[0]["generated_text"]
# Build the Gradio UI: a finetune button with a status box, then the prompt
# and sampling controls with a Generate button wired to predict().
with gr.Blocks() as demo:
    gr.Markdown("## ZeroGPU QLoRA Example for wuhp/myr1")

    # --- Finetuning section ---
    train_button = gr.Button("Finetune 4-bit (QLoRA) on 50 lines of WikiText-2 (up to 10 min)")
    train_status = gr.Textbox(label="Finetune Status")
    train_button.click(fn=finetune_small_subset, outputs=train_status)

    # --- Generation section ---
    gr.Markdown("Then generate text below (or skip finetuning to see base model).")
    prompt_box = gr.Textbox(lines=3, label="Prompt")
    temperature_slider = gr.Slider(minimum=0.0, maximum=1.5, step=0.1, value=0.7, label="Temperature")
    top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Top-p")
    min_tokens_slider = gr.Slider(minimum=260, maximum=5000, value=260, step=10, label="Min New Tokens")
    max_tokens_slider = gr.Slider(minimum=260, maximum=5000, value=500, step=50, label="Max New Tokens")
    generated_box = gr.Textbox(label="Generated Text", lines=12)

    generate_button = gr.Button("Generate")
    # Order must match predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens).
    generation_inputs = [
        prompt_box,
        temperature_slider,
        top_p_slider,
        min_tokens_slider,
        max_tokens_slider,
    ]
    generate_button.click(fn=predict, inputs=generation_inputs, outputs=generated_box)

demo.launch()
|