myr1-2 / app.py
wuhp's picture
Update app.py
13b1681 verified
raw
history blame
7.78 kB
import gradio as gr
import spaces
import torch
from datasets import load_dataset
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
pipeline,
BitsAndBytesConfig, # for 4-bit config
)
# PEFT (LoRA / QLoRA)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
##############################################################################
# ZeroGPU + QLoRA Example
##############################################################################
# Module-level cache for the text-generation pipeline. It starts as None and is
# assigned by finetune_small_subset() (LoRA-adapted model) or lazily by
# ensure_pipeline() (plain 4-bit base model).
TEXT_PIPELINE = None
# Upper bound on the number of WikiText-2 rows selected for the fine-tuning run.
NUM_EXAMPLES = 50 # We'll train on 50 lines of WikiText-2 for demonstration
def _load_4bit_base_model(config, bnb_config):
    """Load the 'wuhp/myr1' causal LM with the given 4-bit quantization config."""
    return AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        config=config,
        quantization_config=bnb_config,  # QLoRA: base weights stored in 4-bit
        device_map="auto",
        trust_remote_code=True,
    )


@spaces.GPU(duration=600)  # up to 10 min
def finetune_small_subset():
    """
    Fine-tune 'wuhp/myr1' with QLoRA on a tiny WikiText-2 subset.

    Steps:
      1) Load up to NUM_EXAMPLES non-blank lines of WikiText-2 (train split).
      2) Load 'wuhp/myr1' in 4-bit (NF4, double quantization, bfloat16 compute).
      3) Attach trainable LoRA adapters to the "q_proj" / "v_proj" modules.
      4) Train for one epoch with the Hugging Face Trainer.
      5) Save the LoRA adapter and tokenizer to 'finetuned_myr1'.
      6) Reload a fresh 4-bit base model, apply the saved adapter, and store a
         text-generation pipeline in the module-level TEXT_PIPELINE.

    Returns:
        str: A human-readable status message for the UI.
    """
    # Function-scope imports: PeftModel is not part of the module's top-level
    # peft import, and gc is only needed here.
    import gc

    from peft import PeftModel

    global TEXT_PIPELINE

    # --- 1) Load WikiText-2 subset ---
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # The raw split contains many blank separator lines; they would tokenize to
    # (near-)empty sequences that contribute nothing useful to training.
    ds = ds.filter(lambda ex: bool(ex["text"].strip()))
    ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))

    # --- 2) Setup 4-bit quantization with BitsAndBytes (QLoRA approach) ---
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if preferred
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",  # "nf4" is standard for QLoRA
    )
    config = AutoConfig.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "wuhp/myr1",
        subfolder="myr1",
        trust_remote_code=True,
    )
    # The LM collator pads each batch; padding requires a pad token, so fall
    # back to EOS when the tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = _load_4bit_base_model(config, bnb_config)
    # Training-only preparation of the quantized model (see the PEFT docs for
    # prepare_model_for_kbit_training for the exact steps it performs).
    base_model = prepare_model_for_kbit_training(base_model)

    # --- 3) Create LoRA config & wrap the base model in a LoRA adapter ---
    # For LLaMA-like models, "q_proj" and "v_proj" are typical. If the model
    # uses different layer names, adjust target_modules accordingly
    # (e.g. "c_attn", "W_pack", "query_key_value").
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    lora_model = get_peft_model(base_model, lora_config)

    # --- 4) Tokenize dataset ---
    def tokenize_fn(ex):
        return tokenizer(ex["text"], truncation=True, max_length=512)

    ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
    ds.set_format("torch")

    # Causal-LM collator (mlm=False).
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="finetuned_myr1",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=5,
        save_steps=999999,  # effectively disables intermediate checkpoints
        save_total_limit=1,
        fp16=False,  # We'll rely on bnb_4bit/bfloat16 for the base model
    )
    trainer = Trainer(
        model=lora_model,
        args=training_args,
        train_dataset=ds,
        data_collator=collator,
    )

    # --- 5) Train, then save LoRA adapter + tokenizer ---
    trainer.train()
    trainer.model.save_pretrained("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

    # Release the training model before loading a second copy of the base
    # model, so both do not have to fit in GPU memory at once.
    del trainer, lora_model, base_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- 6) Reload the base model in 4-bit and apply the saved adapter ---
    # No prepare_model_for_kbit_training here: that is training preparation.
    # PeftModel.from_pretrained reads both the adapter config and the trained
    # weights from disk, instead of wrapping the model in a fresh, randomly
    # initialised adapter and then trying to overwrite it.
    inference_base = _load_4bit_base_model(config, bnb_config)
    inference_model = PeftModel.from_pretrained(inference_base, "finetuned_myr1")
    inference_model.eval()

    # NOTE(review): on ZeroGPU, @spaces.GPU functions may run in a separate
    # worker process - confirm this assignment is visible to later predict()
    # calls; the adapter saved on disk is the durable artifact either way.
    TEXT_PIPELINE = pipeline("text-generation", model=inference_model, tokenizer=tokenizer)
    return "Finetuning complete (QLoRA + LoRA). Model loaded for inference."
def ensure_pipeline():
    """
    Return the shared text-generation pipeline, creating it on first use.

    When TEXT_PIPELINE is already set it is returned as-is. Otherwise the
    'wuhp/myr1' base model is loaded in 4-bit with NO LoRA adapter, wrapped in
    a text-generation pipeline, cached in TEXT_PIPELINE, and returned.
    """
    global TEXT_PIPELINE

    # Fast path: a pipeline has already been built.
    if TEXT_PIPELINE is not None:
        return TEXT_PIPELINE

    # Arguments shared by every from_pretrained call against the model repo.
    repo_kwargs = {"subfolder": "myr1", "trust_remote_code": True}

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model_cfg = AutoConfig.from_pretrained("wuhp/myr1", **repo_kwargs)
    tok = AutoTokenizer.from_pretrained("wuhp/myr1", **repo_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        "wuhp/myr1",
        config=model_cfg,
        quantization_config=quant_cfg,
        device_map="auto",
        **repo_kwargs,
    )

    TEXT_PIPELINE = pipeline("text-generation", model=model, tokenizer=tok)
    return TEXT_PIPELINE
@spaces.GPU(duration=120)  # up to 2 min for text generation
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    """
    Generate text from the finetuned (LoRA) model if present, else the base model.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature. A value <= 0 selects greedy decoding,
            because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        min_new_tokens: Minimum number of tokens to generate; clamped so it
            never exceeds max_new_tokens.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        str: The "generated_text" field of the first pipeline result.
    """
    pipe = ensure_pipeline()

    temperature = float(temperature)
    max_new = int(max_new_tokens)
    # The two sliders are independent, so the UI can send min > max; clamp the
    # minimum rather than handing generate() contradictory limits.
    min_new = max(0, min(int(min_new_tokens), max_new))

    gen_kwargs = {"min_new_tokens": min_new, "max_new_tokens": max_new}
    if temperature > 0.0:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
        )
    else:
        # The Temperature slider bottoms out at 0.0; treat that as greedy
        # decoding and omit the sampling-only parameters.
        gen_kwargs["do_sample"] = False

    out = pipe(prompt, **gen_kwargs)
    return out[0]["generated_text"]
# Build Gradio UI: a fine-tuning trigger with a status box, followed by a
# prompt, sampling controls, and an output box for text generation.
with gr.Blocks() as demo:
    gr.Markdown("## ZeroGPU QLoRA Example for wuhp/myr1")

    # --- Fine-tuning section ---
    finetune_btn = gr.Button(
        "Finetune 4-bit (QLoRA) on 50 lines of WikiText-2 (up to 10 min)"
    )
    status_box = gr.Textbox(label="Finetune Status")
    # No inputs: the handler's returned status string fills the status box.
    finetune_btn.click(finetune_small_subset, None, status_box)

    # --- Generation section ---
    gr.Markdown("Then generate text below (or skip finetuning to see base model).")
    prompt_in = gr.Textbox(label="Prompt", lines=3)
    temperature = gr.Slider(
        minimum=0.0, maximum=1.5, value=0.7, step=0.1, label="Temperature"
    )
    top_p = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p"
    )
    min_tokens = gr.Slider(
        minimum=260, maximum=5000, value=260, step=10, label="Min New Tokens"
    )
    max_tokens = gr.Slider(
        minimum=260, maximum=5000, value=500, step=50, label="Max New Tokens"
    )
    output_box = gr.Textbox(label="Generated Text", lines=12)
    gen_btn = gr.Button("Generate")

    # Input order matches predict()'s positional parameters.
    gen_btn.click(
        predict,
        [prompt_in, temperature, top_p, min_tokens, max_tokens],
        output_box,
    )

demo.launch()