|
import spaces |
|
import os |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from huggingface_hub import snapshot_download |
|
import torch |
|
from accelerate import Accelerator |
|
|
|
|
|
accelerator = Accelerator()

# Download (or reuse the locally cached copy of) the model repository.
# The repo id can be overridden through the REPO_ID environment variable.
model_path = snapshot_download(
    repo_id=os.environ.get("REPO_ID", "SimpleBerry/LLaMA-O1-Supervised-1129"),
)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Half-precision weights, placed automatically across available devices;
# switched to inference mode right away (eval() returns the same module).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
|
|
|
# Markdown header shown at the top of the demo page (passed to gr.Markdown).
# NOTE(review): "Zero Space" presumably means Hugging Face ZeroGPU Spaces —
# consider rewording the user-facing text.
DESCRIPTION = '''
# SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized for Streaming and Hugging Face Zero Space.
This model is experimental and focused on advancing AI reasoning capabilities.

**To start a new chat**, click "clear" and begin a fresh dialogue.
'''

# Markdown license footer shown after the chat interface (passed to gr.Markdown).
LICENSE = """
--- MIT License ---
"""
|
|
|
# Prompt skeleton for the LLaMA-O1 tree format: a root node (father -1,
# local id 0) carrying the user's problem with a positive rating, followed by
# node 1 (child of node 0) whose open <expansion> thought the model continues.
template = "<start_of_father_id>-1<end_of_father_id><start_of_local_id>0<end_of_local_id><start_of_thought><problem>{content}<end_of_thought><start_of_rating><positive_rating><end_of_rating>\n<start_of_father_id>0<end_of_father_id><start_of_local_id>1<end_of_local_id><start_of_thought><expansion>"


def llama_o1_template(data):
    """Return the LLaMA-O1 prompt with *data* inserted as the problem text.

    Only the ``{content}`` placeholder of ``template`` is substituted; any
    braces inside *data* are copied through verbatim.
    """
    prompt = template.format_map({"content": data})
    return prompt
|
|
|
|
|
@spaces.GPU
def gen_one_token(inputs, temperature, top_p):
    """Sample exactly one new token continuing *inputs*.

    Decorated with ``@spaces.GPU`` so each call requests a GPU for its
    duration. *inputs* is the tokenizer output (unpacked into ``generate``);
    the full ``generate`` result is returned (``return_dict_in_generate=True``),
    so callers read token ids from its ``sequences`` field.
    """
    sampling_options = dict(
        max_new_tokens=1,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        use_cache=True,
        # Reuse EOS as the pad id so generate() does not warn about padding.
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=False,
    )
    return model.generate(**inputs, **sampling_options)
|
|
|
|
|
def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
    """Stream the model's reasoning for *message*, one sampled token per step.

    Args:
        message: The user's latest chat message; it becomes the ``<problem>``
            node of the prompt built by ``llama_o1_template``.
        history: Previous chat turns supplied by Gradio. Ignored: every
            message starts a fresh reasoning tree.
        max_tokens: Upper bound on generated tokens. Gradio sliders may pass
            this as a float, so it is cast to ``int``.
        temperature: Sampling temperature forwarded to ``gen_one_token``.
        top_p: Nucleus-sampling threshold forwarded to ``gen_one_token``.

    Yields:
        The cumulative response text generated so far. ``gr.ChatInterface``
        replaces the displayed reply with each yielded value, so yielding only
        the newest token would show a single token at a time.
    """
    input_text = llama_o1_template(message)
    response = ""
    for _ in range(int(max_tokens)):
        # Re-tokenize the full text each step: gen_one_token is a separate
        # @spaces.GPU call, so no KV cache is kept between steps.
        inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
        output = gen_one_token(inputs, temperature, top_p)

        # generate() was called with return_dict_in_generate=True, so the ids
        # live in output.sequences (prompt + new token). Decoding the output
        # object itself fails; slice off the prompt to keep only the new token.
        new_ids = output.sequences[0, inputs["input_ids"].shape[-1]:]
        if new_ids.numel() == 0 or new_ids[-1].item() == tokenizer.eos_token_id:
            break  # the model ended its answer

        new_text = tokenizer.decode(new_ids, skip_special_tokens=False)
        input_text += new_text
        response += new_text
        yield response
|
|
|
# Build the UI. The sampling sliders are created unrendered and handed to
# ChatInterface as additional_inputs, so their values are actually passed to
# generate_text (as max_tokens, temperature, top_p, in that order).
# ChatInterface renders them inside the "Adjust Parameters" accordion.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    max_tokens_slider = gr.Slider(
        minimum=128, maximum=2048, value=512, step=1,
        label="Max Tokens", render=False,
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=1.5, value=0.9, step=0.1,
        label="Temperature", render=False,
    )
    top_p_slider = gr.Slider(
        minimum=0.05, maximum=1.0, value=0.95, step=0.01,
        label="Top-p (nucleus sampling)", render=False,
    )

    chatbot = gr.ChatInterface(
        generate_text,
        title="SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized Demo",
        description="Adjust settings below as needed.",
        additional_inputs=[max_tokens_slider, temperature_slider, top_p_slider],
        additional_inputs_accordion=gr.Accordion(
            "Adjust Parameters", open=False, render=False
        ),
        # Each example fills only the message box; the sliders keep their values.
        examples=[
            ["How many r's are in the word strawberry?"],
            ['If Diana needs to bike 10 miles to reach home and she can bike at a speed of 3 mph for two hours before getting tired, and then at a speed of 1 mph until she reaches home, how long will it take her to get home?'],
            ['Find the least odd prime factor of $2019^8+1$.'],
        ],
        cache_examples=False,
        fill_height=True,
    )

    gr.Markdown(LICENSE)


if __name__ == "__main__":
    demo.launch()
|
|