Di Zhang committed on
Commit
669aad1
·
verified ·
1 Parent(s): db4b49f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -47,7 +47,28 @@ def format_response(response):
47
  response = response.replace('<negative_rating>','👎')
48
 
49
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
 
 
 
 
 
51
  input_text = llama_o1_template(message)
52
  inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
53
 
 
47
  response = response.replace('<negative_rating>','👎')
48
 
49
@spaces.GPU
def generate_text_gpu(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
    """Generate a chat response on the GPU for a single user message.

    Parameters:
        message: raw user input; formatted into a prompt via ``llama_o1_template``.
        history: chat history — unused here, kept to match the chat-callback
            signature shared with ``generate_text``.
        max_tokens: maximum number of NEW tokens to generate.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_p: nucleus-sampling cutoff forwarded to ``model.generate``.

    Yields:
        The decoded response string. Special tokens are kept
        (``skip_special_tokens=False``) so downstream formatting can replace
        rating markers such as ``<negative_rating>``.
    """
    input_text = llama_o1_template(message)
    inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)

    # Sample a completion. Use max_new_tokens rather than max_length:
    # max_length caps prompt + completion combined, so a prompt near or over
    # `max_tokens` tokens would silently truncate (or zero out) the generation.
    output = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    response = tokenizer.decode(output[0], skip_special_tokens=False)
    yield response
66
  def generate_text(message, history, max_tokens=512, temperature=0.9, top_p=0.95):
67
+ try:
68
+ yield generate_text_gpu(message, history, max_tokens=512, temperature=0.9, top_p=0.95)
69
+ return
70
+ except Exception as e:
71
+ print(e)
72
  input_text = llama_o1_template(message)
73
  inputs = tokenizer(input_text, return_tensors="pt").to(accelerator.device)
74