Di Zhang
Update app.py
15cdd1d verified
raw
history blame
3.35 kB
import spaces
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download
import torch
from accelerate import Accelerator
# Accelerator is used in this file only to choose the device that tokenized
# prompts are moved to (``accelerator.device``); model placement is handled
# separately by ``device_map="auto"`` below.
accelerator = Accelerator()

# Download (or reuse a cached copy of) the model repo. The REPO_ID environment
# variable overrides the default repository.
model_path = snapshot_download(
    repo_id=os.environ.get("REPO_ID", "SimpleBerry/LLaMA-O1-Supervised-1129"),
)

# Tokenizer and half-precision weights both come from the same local snapshot.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)
# eval() switches the module in place (and returns it); this demo only does inference.
model.eval()
# Markdown banner shown at the top of the demo page.
DESCRIPTION = '''
# SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized for Streaming and Hugging Face Zero Space.
This model is experimental and focused on advancing AI reasoning capabilities.
**To start a new chat**, click "clear" and begin a fresh dialogue.
'''

# Markdown footer shown below the chat interface.
LICENSE = """
--- MIT License ---
"""

# LLaMA-O1 prompt format. Read literally, the markers describe two nodes:
# local id 0 (father -1) carries the user's <problem> with a positive rating,
# and local id 1 (father 0) is opened as an <expansion> thought that the model
# continues. ``{content}`` is the only format field.
template = (
    "<start_of_father_id>-1<end_of_father_id><start_of_local_id>0<end_of_local_id>"
    "<start_of_thought><problem>{content}<end_of_thought>"
    "<start_of_rating><positive_rating><end_of_rating>\n"
    "<start_of_father_id>0<end_of_father_id><start_of_local_id>1<end_of_local_id>"
    "<start_of_thought><expansion>"
)


def llama_o1_template(data):
    """Return ``data`` embedded as the problem text of the LLaMA-O1 prompt ``template``.

    ``data`` is substituted as a value, so braces inside it are not treated
    as format fields.
    """
    return template.format(content=data)
@spaces.GPU
def gen_one_token(inputs, temperature, top_p):
    """Sample exactly one new token continuing ``inputs``.

    Decorated with ``spaces.GPU`` so each call requests a GPU for its duration.

    Args:
        inputs: Mapping of keyword tensors unpacked into ``model.generate``
            (e.g. the ``input_ids``/``attention_mask`` from the tokenizer).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The dict-style generation output (``return_dict_in_generate=True``);
        for this causal LM its ``sequences`` hold the input ids followed by the
        one newly sampled token.
    """
    generation_options = {
        "max_new_tokens": 1,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "use_cache": True,
        # EOS doubles as the pad id (presumably because the tokenizer defines
        # no dedicated pad token).
        "pad_token_id": tokenizer.eos_token_id,
        "return_dict_in_generate": True,
        "output_scores": False,
    }
    return model.generate(**inputs, **generation_options)
def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
    """Stream a LLaMA-O1 completion for ``message``, one sampled token per step.

    Args:
        message: The user's latest chat message; wrapped with
            ``llama_o1_template`` before tokenization.
        history: Prior turns supplied by ``gr.ChatInterface``. Unused: every
            request is answered from the single templated message.
        max_tokens: Upper bound on the number of new tokens to sample.
        temperature: Sampling temperature forwarded to ``gen_one_token``.
        top_p: Nucleus-sampling cutoff forwarded to ``gen_one_token``.

    Yields:
        The whole completion decoded so far (special tokens kept), growing by
        one token per step. ``gr.ChatInterface`` replaces the displayed reply
        with each yielded value, so the text must be cumulative.
    """
    inputs = tokenizer(llama_o1_template(message), return_tensors="pt").to(accelerator.device)
    prompt_len = inputs["input_ids"].shape[-1]
    eos_id = tokenizer.eos_token_id

    # Slider values can arrive as floats; range() requires an int.
    for _ in range(int(max_tokens)):
        # Each call is a separate GPU request that re-processes the whole prefix
        # (no KV cache is carried between calls), so per-step cost grows with length.
        output = gen_one_token(inputs, temperature, top_p)
        # `output` is a generation-output object, not token ids: the ids live
        # in `.sequences` (prompt + tokens sampled so far + 1 new token).
        sequences = output.sequences.to(accelerator.device)

        # Feed the ids straight back instead of decoding and re-tokenizing,
        # which can change the token sequence between steps. Batch of one with
        # no padding, so the attention mask is all ones.
        inputs = {
            "input_ids": sequences,
            "attention_mask": torch.ones_like(sequences),
        }

        # Decode only the generated part; special tokens are kept because the
        # O1 template's structural markers are meant to be shown.
        yield tokenizer.decode(sequences[0, prompt_len:], skip_special_tokens=False)

        if eos_id is not None and sequences[0, -1].item() == eos_id:
            break
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    # The sliders are created unrendered and handed to ChatInterface as
    # additional inputs. It renders them in its own accordion and passes their
    # values to generate_text after (message, history), in this order:
    # max_tokens, temperature, top_p. Previously they were built after the
    # ChatInterface and never connected, so they had no effect.
    max_tokens_slider = gr.Slider(
        minimum=128, maximum=2048, value=512, step=1, label="Max Tokens", render=False
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=1.5, value=0.9, step=0.1, label="Temperature", render=False
    )
    top_p_slider = gr.Slider(
        minimum=0.05, maximum=1.0, value=0.95, step=0.01, label="Top-p (nucleus sampling)", render=False
    )

    chatbot = gr.ChatInterface(
        generate_text,
        title="SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized Demo",
        description="Adjust settings below as needed.",
        examples=[
            ["How many r's are in the word strawberry?"],
            ['If Diana needs to bike 10 miles to reach home and she can bike at a speed of 3 mph for two hours before getting tired, and then at a speed of 1 mph until she reaches home, how long will it take her to get home?'],
            ['Find the least odd prime factor of $2019^8+1$.'],
        ],
        cache_examples=False,
        fill_height=True,
        additional_inputs=[max_tokens_slider, temperature_slider, top_p_slider],
        # render=False: ChatInterface only reads this accordion's label/open
        # settings; it must not be rendered on its own inside the Blocks.
        additional_inputs_accordion=gr.Accordion(label="Adjust Parameters", open=False, render=False),
    )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()