Di Zhang committed on
Commit
2aced17
·
verified ·
1 Parent(s): bf7cf6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -38,16 +38,12 @@ def llama_o1_template(data):
38
  text = template.format(content=data)
39
  return text
40
 
41
- @spaces.GPU
42
- def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
43
- input_text = llama_o1_template(message)
44
- inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
45
 
46
- # Stream generation, token by token
47
- with torch.no_grad():
48
- for output in model.generate(
49
  **inputs,
50
- max_length=max_tokens,
51
  temperature=temperature,
52
  top_p=top_p,
53
  do_sample=True,
@@ -55,10 +51,19 @@ def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95)
55
  pad_token_id=tokenizer.eos_token_id,
56
  return_dict_in_generate=True,
57
  output_scores=False
58
- ):
59
- # Return text with special tokens included
60
- generated_text = tokenizer.decode(output, skip_special_tokens=False)
61
- yield generated_text
 
 
 
 
 
 
 
 
 
62
 
63
  with gr.Blocks() as demo:
64
  gr.Markdown(DESCRIPTION)
 
38
  text = template.format(content=data)
39
  return text
40
 
 
 
 
 
41
 
42
@spaces.GPU
def gen_one_token(inputs, temperature, top_p):
    """Sample exactly one new token for the given tokenized inputs.

    Args:
        inputs: tokenizer output (``BatchEncoding``) already moved to the
            model's device by the caller.
        temperature: sampling temperature passed to ``model.generate``.
        top_p: nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        A generate ``ModelOutput`` (``return_dict_in_generate=True``) whose
        ``.sequences`` holds the input ids plus the single sampled token.
    """
    # BUG FIX: the committed `def` line was missing its trailing colon,
    # which is a SyntaxError and would prevent the app from starting.
    output = model.generate(
        **inputs,
        max_new_tokens=1,  # one token per call; the caller loops to stream
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=False,
    )
    return output
56
+
57
+
58
def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
    """Stream generated text one token at a time for the Gradio chat UI.

    Args:
        message: user prompt; wrapped by ``llama_o1_template`` before encoding.
        history: chat history (unused; required by the Gradio chat interface).
        max_tokens: upper bound on the number of tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.

    Yields:
        The decoded text of each newly generated token (special tokens kept).
    """
    input_text = llama_o1_template(message)
    for _ in range(max_tokens):
        inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
        output = gen_one_token(inputs, temperature, top_p)
        # BUG FIX: with return_dict_in_generate=True, `output` is a
        # ModelOutput, not a tensor of ids — tokenizer.decode(output, ...)
        # raises. Decode the token ids from `output.sequences`, and only the
        # newly generated last token (decoding the whole sequence would
        # re-emit the prompt on every iteration).
        new_token_id = output.sequences[0, -1]
        # Return text with special tokens included
        generated_text = tokenizer.decode([new_token_id], skip_special_tokens=False)
        input_text += generated_text
        yield generated_text
        # Stop once the model emits end-of-sequence instead of padding out
        # the full max_tokens budget.
        if new_token_id.item() == tokenizer.eos_token_id:
            break
67
 
68
  with gr.Blocks() as demo:
69
  gr.Markdown(DESCRIPTION)