Di Zhang committed
Commit bf7cf6b · verified · 1 Parent(s): 22dfef8

Update app.py

Files changed (1)
  1. app.py +35 -30
app.py CHANGED
@@ -1,30 +1,31 @@
-
 import spaces
-
 import os
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from huggingface_hub import hf_hub_download, snapshot_download
-import accelerate
+from huggingface_hub import snapshot_download
+import torch
+from accelerate import Accelerator
 
-accelerator = accelerate.Accelerator()
+# Initialize Accelerator for efficient multi-GPU/Zero optimization
+accelerator = Accelerator()
 
-# Load the model and tokenizer from Hugging Face
+# Load the model and tokenizer
 model_path = snapshot_download(
     repo_id=os.environ.get("REPO_ID", "SimpleBerry/LLaMA-O1-Supervised-1129")
 )
 
 tokenizer = AutoTokenizer.from_pretrained(model_path)
-model = AutoModelForCausalLM.from_pretrained(model_path,device_map='auto')
+model = AutoModelForCausalLM.from_pretrained(
+    model_path,
+    torch_dtype=torch.float16,
+    device_map="auto"
+).eval()
 
 DESCRIPTION = '''
-# SimpleBerry/LLaMA-O1-Supervised-1129 | Duplicate the space and set it to private for faster & personal inference for free.
-SimpleBerry/LLaMA-O1-Supervised-1129: an experimental research model developed by the SimpleBerry.
-Focused on advancing AI reasoning capabilities.
-
-## This Space was designed by Lyte/LLaMA-O1-Supervised-1129-GGUF, Many Thanks!
+# SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized for Streaming and Hugging Face Zero Space.
+This model is experimental and focused on advancing AI reasoning capabilities.
 
-**To start a new chat**, click "clear" and start a new dialogue.
+**To start a new chat**, click "clear" and begin a fresh dialogue.
 '''
 
 LICENSE = """
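Reviewer note: the loading change above swaps the implicit fp32 load for explicit float16 plus `.eval()`, roughly halving weight memory. A minimal sketch of what this buys, assuming a CUDA machine (the repo id is inlined here instead of the `snapshot_download` path, for brevity):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the same way the new app.py does: half precision, Accelerate placement.
model = AutoModelForCausalLM.from_pretrained(
    "SimpleBerry/LLaMA-O1-Supervised-1129",
    torch_dtype=torch.float16,   # ~2 bytes per parameter instead of 4
    device_map="auto",           # Accelerate decides GPU/CPU placement
).eval()                         # inference mode: disables dropout etc.

# hf_device_map records where each module landed (all on cuda:0 with one GPU).
print(model.hf_device_map)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params/1e9:.1f}B params, ~{n_params*2/1e9:.1f} GB of weights in fp16")
```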
  LICENSE = """
@@ -34,7 +35,6 @@ LICENSE = """
34
  template = "<start_of_father_id>-1<end_of_father_id><start_of_local_id>0<end_of_local_id><start_of_thought><problem>{content}<end_of_thought><start_of_rating><positive_rating><end_of_rating>\n<start_of_father_id>0<end_of_father_id><start_of_local_id>1<end_of_local_id><start_of_thought><expansion>"
35
 
36
  def llama_o1_template(data):
37
- #query = data['query']
38
  text = template.format(content=data)
39
  return text
40
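For reviewers unfamiliar with the format: the template (unchanged here apart from dropping the dead `#query` line) wraps the user message in the model's tree-structured thought markup, seeding the problem as node 0 and opening an `<expansion>` node 1 for generation to continue. A quick illustration, using the template string from the file above split only for readability:

```python
# The template string from app.py, split across lines for readability.
template = (
    "<start_of_father_id>-1<end_of_father_id><start_of_local_id>0<end_of_local_id>"
    "<start_of_thought><problem>{content}<end_of_thought>"
    "<start_of_rating><positive_rating><end_of_rating>\n"
    "<start_of_father_id>0<end_of_father_id><start_of_local_id>1<end_of_local_id>"
    "<start_of_thought><expansion>"
)

def llama_o1_template(data):
    return template.format(content=data)

# The formatted prompt ends inside <expansion>, so the model's first
# generated tokens continue that open thought node.
print(llama_o1_template("How many r's are in the word strawberry?"))
```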
 
@@ -43,25 +43,30 @@ def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95)
43
  input_text = llama_o1_template(message)
44
  inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
45
 
46
- # Generate the text with the model
47
- output = model.generate(
48
- **inputs,
49
- max_length=max_tokens,
50
- temperature=temperature,
51
- top_p=top_p,
52
- do_sample=True,
53
- )
54
-
55
- response = tokenizer.decode(output[0], skip_special_tokens=True)
56
- yield response
 
 
 
 
 
57
 
58
  with gr.Blocks() as demo:
59
  gr.Markdown(DESCRIPTION)
60
 
61
  chatbot = gr.ChatInterface(
62
  generate_text,
63
- title="SimpleBerry/LLaMA-O1-Supervised-1129 | GGUF Demo",
64
- description="Edit Settings below if needed.",
65
  examples=[
66
  ["How many r's are in the word strawberry?"],
67
  ['If Diana needs to bike 10 miles to reach home and she can bike at a speed of 3 mph for two hours before getting tired, and then at a speed of 1 mph until she reaches home, how long will it take her to get home?'],
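One caveat on the new generation loop: `model.generate()` runs to completion and returns its result in one piece; with `return_dict_in_generate=True`, iterating over the result walks the fields of the output object rather than streaming tokens, so this loop does not yield incrementally. If token-by-token streaming is the intent, the usual transformers pattern is a `TextIteratorStreamer` fed by `generate()` on a background thread. A sketch under that assumption (note `max_new_tokens` rather than `max_length`, which also counts prompt tokens):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
    # Single-GPU assumption: model.device is where the inputs belong.
    inputs = tokenizer(llama_o1_template(message), return_tensors="pt").to(model.device)

    # The streamer receives tokens as generate() produces them; generate()
    # itself blocks, so it runs on a worker thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    Thread(target=model.generate, kwargs=dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,   # budget for generated tokens only
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )).start()

    partial = ""
    for chunk in streamer:   # decoded text fragments, in generation order
        partial += chunk
        yield partial        # ChatInterface re-renders the message on each yield
```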
@@ -72,9 +77,9 @@ with gr.Blocks() as demo:
     )
 
     with gr.Accordion("Adjust Parameters", open=False):
-        gr.Slider(minimum=1024, maximum=8192, value=2048, step=1, label="Max Tokens")
-        gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
-        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.01, label="Top-p (nucleus sampling)")
+        max_tokens_slider = gr.Slider(minimum=128, maximum=2048, value=512, step=1, label="Max Tokens")
+        temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.9, step=0.1, label="Temperature")
+        top_p_slider = gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.01, label="Top-p (nucleus sampling)")
 
     gr.Markdown(LICENSE)
 
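A final note: renaming the sliders gives them handles but still leaves them disconnected; `gr.ChatInterface` only forwards components listed in `additional_inputs`, so `max_tokens`, `temperature`, and `top_p` keep their defaults. A possible wiring, assuming the Gradio 4.x `ChatInterface` API:

```python
# Hand the sliders to ChatInterface so their values reach generate_text as
# its max_tokens / temperature / top_p arguments (matched by position).
chatbot = gr.ChatInterface(
    generate_text,
    title="SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized Demo",
    description="Adjust settings below as needed.",
    additional_inputs=[max_tokens_slider, temperature_slider, top_p_slider],
    additional_inputs_accordion="Adjust Parameters",  # replaces the manual Accordion
)
```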