Di Zhang committed
Commit bda8afc · verified · 1 parent: 15cdd1d

Update app.py

Files changed (1): app.py (+13 −19)
app.py CHANGED
@@ -24,7 +24,6 @@ model = AutoModelForCausalLM.from_pretrained(
 DESCRIPTION = '''
 # SimpleBerry/LLaMA-O1-Supervised-1129 | Optimized for Streaming and Hugging Face Zero Space.
 This model is experimental and focused on advancing AI reasoning capabilities.
-
 **To start a new chat**, click "clear" and begin a fresh dialogue.
 '''
 
@@ -38,12 +37,16 @@ def llama_o1_template(data):
     text = template.format(content=data)
     return text
 
-
 @spaces.GPU
-def gen_one_token(inputs,temperature,top_p):
-    output = model.generate(
+def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
+    input_text = llama_o1_template(message)
+    inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
+
+    # Stream generation, token by token
+    with torch.no_grad():
+        for output in model.generate(
             **inputs,
-            max_new_tokens=1,
+            max_length=max_tokens,
             temperature=temperature,
             top_p=top_p,
             do_sample=True,
@@ -51,19 +54,10 @@ def gen_one_token(inputs,temperature,top_p):
             pad_token_id=tokenizer.eos_token_id,
             return_dict_in_generate=True,
             output_scores=False
-    )
-    return output
-
-
-def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
-    input_text = llama_o1_template(message)
-    for i in range(max_tokens):
-        inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
-        output = gen_one_token(inputs,temperature,top_p)
-        # Return text with special tokens included
-        generated_text = tokenizer.decode(output, skip_special_tokens=False)
-        input_text += generated_text
-        yield generated_text
+        ):
+            # Return text with special tokens included
+            generated_text = tokenizer.decode(output, skip_special_tokens=False)
+            yield generated_text
 
 with gr.Blocks() as demo:
     gr.Markdown(DESCRIPTION)
@@ -89,4 +83,4 @@ with gr.Blocks() as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
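
A note on the streaming pattern in this commit: `model.generate` runs the full decoding loop before it returns, and with `return_dict_in_generate=True` the result is a single output object, so `for output in model.generate(...)` does not yield tokens one at a time. The usual way to stream with `transformers` is `TextIteratorStreamer`, with `generate` running in a background thread. Below is a minimal sketch under that assumption; it reuses `tokenizer`, `model`, `accelerator`, and `llama_o1_template` from app.py, and the function name `generate_text_streaming` is hypothetical, not part of this commit.

import threading

from transformers import TextIteratorStreamer

def generate_text_streaming(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
    # Hypothetical drop-in for generate_text; tokenizer, model, accelerator,
    # and llama_o1_template are assumed to be the globals defined in app.py.
    input_text = llama_o1_template(message)
    inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)

    # The streamer receives token ids from generate() and decodes them
    # incrementally; skip_prompt=True drops the echoed prompt text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,  # budget for new tokens only, unlike max_length
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() blocks until decoding finishes, so run it in a background
    # thread and yield partial text from the main thread as tokens arrive.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial
    thread.join()

The sketch uses `max_new_tokens` rather than `max_length`, so the token budget applies only to newly generated tokens; `max_length` also counts the prompt. Yielding the accumulated `partial` string matches what Gradio chat components expect from a streaming generator.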