import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ----------------------------------------------------------------
# 1) Point to your Hugging Face repo and subfolder
#    (where config.json, tokenizer.json, the model safetensors, etc. reside).
# ----------------------------------------------------------------
MODEL_REPO = "wuhp/myr1"
SUBFOLDER = "myr1"
# ----------------------------------------------------------------
# 2) Load the tokenizer
#    trust_remote_code=True allows custom code (e.g., DeepSeek config/classes).
# ----------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_REPO,
    subfolder=SUBFOLDER,
    trust_remote_code=True
)
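# Some causal-LM tokenizers ship without a dedicated pad token. Falling back to
# the EOS token is a common convention (an assumption here, consistent with
# pad_token_id=eos_token_id used in generate() below).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token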
# ----------------------------------------------------------------
# 3) Load the model
#    - device_map="auto" tries to place layers on the GPU and offloads the
#      remainder to the CPU if needed.
#    - torch_dtype can be float16, float32, bfloat16, etc., depending on GPU support.
# ----------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    MODEL_REPO,
    subfolder=SUBFOLDER,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

# Put the model in evaluation mode (inference only)
model.eval()
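# Optional sketch (not enabled here): if the model does not fit in GPU memory
# at float16, 4-bit loading via bitsandbytes is a common fallback. This assumes
# the `bitsandbytes` package is installed and that the custom model code
# supports quantized loading.
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_REPO,
#       subfolder=SUBFOLDER,
#       trust_remote_code=True,
#       device_map="auto",
#       quantization_config=bnb_config,
#   )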
# ----------------------------------------------------------------
# 4) Define the generation function
# ----------------------------------------------------------------
def generate_text(prompt, max_length=64, temperature=0.7, top_p=0.9):
    print("=== Starting generation ===")

    # Move the input tokens to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    try:
        # Generate tokens
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_length,  # controls how many tokens beyond the prompt are generated
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        print("=== Generation complete ===")
    except Exception as e:
        print(f"Error during generation: {e}")
        return str(e)

    # Decode back to text (skipping special tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
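# Illustrative usage (not executed at import time): the UI below passes the
# slider values into generate_text, but it can also be called directly for a
# quick smoke test, e.g.:
#
#   print(generate_text("Hello! Write one sentence about the ocean.", max_length=32))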
# ----------------------------------------------------------------
# 5) Build a Gradio UI
# ----------------------------------------------------------------
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Prompt",
            placeholder="Try a short prompt, e.g., Hello!"
        ),
        gr.Slider(8, 512, value=64, step=1, label="Max New Tokens"),
        gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
    ],
    outputs="text",
    title="DeepSeek R1 Demo",
    description="Generates text using the large DeepSeek model."
)
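# Note (optional): generation can take a while on large models; enabling
# Gradio's request queue (demo.queue() before launch) is a common way to avoid
# request timeouts. It is left disabled here to keep the demo minimal.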
# ----------------------------------------------------------------
# 6) Run the Gradio app
# ----------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()