import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model repository on the Hugging Face Hub; the weights live in a subfolder.
MODEL_REPO = "wuhp/myr1"
SUBFOLDER = "myr1"

# Load the tokenizer; trust_remote_code lets any custom code in the repo run.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_REPO,
    subfolder=SUBFOLDER,
    trust_remote_code=True
)

# Load the model in fp16 and let device_map="auto" place layers across
# the available devices.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_REPO,
    subfolder=SUBFOLDER,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

# Inference only: disable dropout and other training-time behavior.
model.eval()

def generate_text(prompt, max_new_tokens=64, temperature=0.7, top_p=0.9):
    print("=== Starting generation ===")
    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    try:
        # No gradients are needed for generation.
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        print("=== Generation complete ===")
    except Exception as e:
        print(f"Error during generation: {e}")
        return str(e)
    # Decode the full sequence (prompt plus completion), dropping special tokens.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Prompt",
            placeholder="Try a short prompt, e.g., Hello!"
        ),
        gr.Slider(8, 512, value=64, step=1, label="Max New Tokens"),
        # Minimum of 0.1: generate() rejects temperature=0 when sampling is enabled.
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
    ],
    outputs="text",
    title="DeepSeek R1 Demo",
    description="Generates text with the DeepSeek R1 model."
)

if __name__ == "__main__":
    demo.launch()
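
# Quick sanity check without the UI (a sketch; assumes this file is saved
# as app.py and importable as `app`):
#
#   from app import generate_text
#   print(generate_text("Hello!", max_new_tokens=32))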