import gradio as gr
import spaces
from datasets import load_dataset
import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

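# `spaces` is Hugging Face's ZeroGPU helper package: its @spaces.GPU decorator
# requests a GPU only for the duration of the decorated call, which is why the
# training and generation entry points below are each wrapped separately.
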
##############################################################################
# GLOBALS / ZERO-GPU APPROACH
##############################################################################
# We store a global pipeline after finetuning (if any).
TEXT_PIPELINE = None
# We'll train on only 50 examples from WikiText-2 to keep it short.
NUM_EXAMPLES = 50

@spaces.GPU(duration=600)  # up to 600 seconds (10 minutes) for mini-finetuning
def finetune_small_subset():
"""
1) Loads 'wuhp/myr1' in 8-bit,
2) Takes 50 examples from WikiText-2,
3) Finetunes for 1 epoch,
4) Saves to 'finetuned_myr1/',
5) Reloads the new model into a pipeline for inference.
"""
# 1) Load dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# Keep only 50 to fit ephemeral GPU time
ds = ds.select(range(min(NUM_EXAMPLES, len(ds))))
# 2) Load config, tokenizer, model
config = AutoConfig.from_pretrained(
"wuhp/myr1",
subfolder="myr1",
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
"wuhp/myr1",
subfolder="myr1",
trust_remote_code=True
)
# 8-bit loading via bitsandbytes
model = AutoModelForCausalLM.from_pretrained(
"wuhp/myr1",
subfolder="myr1",
config=config,
load_in_8bit=True, # <--- 8-bit
device_map="auto", # let HF manage device placement
trust_remote_code=True
)
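
    # NOTE (hedged): newer transformers releases deprecate the bare
    # `load_in_8bit=True` kwarg in favor of an explicit quantization config,
    # roughly like this:
    #
    #     from transformers import BitsAndBytesConfig
    #     model = AutoModelForCausalLM.from_pretrained(
    #         "wuhp/myr1", subfolder="myr1", config=config,
    #         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #         device_map="auto", trust_remote_code=True,
    #     )
    #
    # Also, full finetuning of a purely 8-bit model can fail on some
    # transformers versions; a common workaround (assumption, not part of
    # this script) is `peft.prepare_model_for_kbit_training(model)` plus a
    # LoRA adapter instead of updating the quantized weights directly.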

    # 3) Tokenize
    def tokenize_fn(ex):
        return tokenizer(ex["text"], truncation=True, max_length=512)

    ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
    ds.set_format("torch")

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
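    # (Illustrative, hedged) With mlm=False the collator pads a batch, copies
    # input_ids into labels, and masks padded positions with -100, so the
    # Trainer computes a standard causal-LM loss:
    #
    #     batch = collator([ds[i] for i in range(2)])
    #     # batch["labels"] equals batch["input_ids"] except at padding (-100)
    #
    # WikiText-2 also contains empty rows; filtering them before tokenization
    # (e.g. ds = ds.filter(lambda ex: ex["text"].strip())) is a common
    # precaution (assumption, not done here).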

    # 4) TrainingArguments: no fp16 to avoid half-precision gradient issues
    training_args = TrainingArguments(
        output_dir="finetuned_myr1",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        logging_steps=10,
        save_steps=999999,   # skip mid-training saves
        save_total_limit=1,
        fp16=False,          # <--- disable FP16
    )
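    # NOTE (assumption): on Ampere-or-newer GPUs, `bf16=True` is a common
    # alternative here; it keeps memory low without fp16's loss-scaling issues.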

    # 5) Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        data_collator=collator,
    )

    # 6) Train
    trainer.train()

    # 7) Save final model
    trainer.save_model("finetuned_myr1")
    tokenizer.save_pretrained("finetuned_myr1")

    # 8) Reload the newly finetuned model as a pipeline (for inference)
    finetuned_model = AutoModelForCausalLM.from_pretrained(
        "finetuned_myr1",
        device_map="auto",
        trust_remote_code=True
    )
    global TEXT_PIPELINE
    TEXT_PIPELINE = pipeline("text-generation", model=finetuned_model, tokenizer=tokenizer)
    return "Finetuning complete! Model reloaded for inference."

def ensure_pipeline():
"""
If no pipeline yet, load the original model from wuhp/myr1 for inference.
(In 8-bit or normal float? We can do normal float here for a simpler approach.)
"""
global TEXT_PIPELINE
if TEXT_PIPELINE is None:
tokenizer = AutoTokenizer.from_pretrained(
"wuhp/myr1",
subfolder="myr1",
trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
"wuhp/myr1",
subfolder="myr1",
trust_remote_code=True,
load_in_8bit=True, # load in 8-bit also for inference
device_map="auto"
)
TEXT_PIPELINE = pipeline("text-generation", model=model, tokenizer=tokenizer)
return TEXT_PIPELINE
@spaces.GPU(duration=120) # up to 120s for text generation
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
"""
Generates text from either the finetuned pipeline (if it exists) or the base model.
Allows user to adjust temperature, top_p, min/max tokens.
"""
pipe = ensure_pipeline()
out = pipe(
prompt,
temperature=float(temperature),
top_p=float(top_p),
min_new_tokens=int(min_new_tokens),
max_new_tokens=int(max_new_tokens),
do_sample=True
)
return out[0]["generated_text"]
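# Example (illustrative): the decorated function is still directly callable,
# bypassing the sliders below:
#
#     print(predict("Once upon a time", 0.7, 0.9, 30, 120))
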
# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## ZeroGPU: Mini-Finetune with 8-bit + Extended Generation")

    finetune_btn = gr.Button("Finetune on 50 lines of WikiText-2 (up to 10 min)")
    status_box = gr.Textbox(label="Finetune Status")
    finetune_btn.click(fn=finetune_small_subset, outputs=status_box)

    gr.Markdown("After finetuning, or even without it, generate text below:")
    prompt_in = gr.Textbox(lines=3, label="Prompt")
    temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
    min_tokens = gr.Slider(260, 5000, value=260, step=10, label="Min New Tokens")
    max_tokens = gr.Slider(260, 5000, value=500, step=50, label="Max New Tokens")
    output_box = gr.Textbox(label="Generated Text", lines=12)
    gen_btn = gr.Button("Generate")

    gen_btn.click(
        fn=predict,
        inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
        outputs=output_box
    )

demo.launch()
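
# NOTE (assumption): on Spaces it can help to enqueue concurrent GPU requests
# so they are served one at a time, e.g. `demo.queue().launch()` instead of
# the bare launch above.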