charanhu commited on
Commit
3ac57f8
·
1 Parent(s): bb14ba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -15,7 +15,7 @@ class StopOnTokens(StoppingCriteria):
15
  return True
16
  return False
17
 
18
- def predict(message, history):
19
 
20
  history_transformer_format = history + [[message, ""]]
21
  stop = StopOnTokens()
@@ -24,15 +24,14 @@ def predict(message, history):
24
  for item in history_transformer_format])
25
 
26
  model_inputs = tokenizer([messages], return_tensors="pt")
27
- streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
28
  generate_kwargs = dict(
29
  model_inputs,
30
- streamer=streamer,
31
- max_new_tokens=1024,
32
  do_sample=True,
33
  top_p=1,
34
  top_k=50,
35
- temperature=1.0,
36
  num_beams=1,
37
  stopping_criteria=StoppingCriteriaList([stop])
38
  )
@@ -46,4 +45,12 @@ def predict(message, history):
46
  yield partial_message
47
 
48
 
49
- gr.ChatInterface(predict).queue().launch()
 
 
 
 
 
 
 
 
 
15
  return True
16
  return False
17
 
18
+ def predict(message, history, temperature, max_new_tokens, min_new_tokens):
19
 
20
  history_transformer_format = history + [[message, ""]]
21
  stop = StopOnTokens()
 
24
  for item in history_transformer_format])
25
 
26
  model_inputs = tokenizer([messages], return_tensors="pt")
 
27
  generate_kwargs = dict(
28
  model_inputs,
29
+ max_new_tokens=int(max_new_tokens),
30
+ min_new_tokens=int(min_new_tokens),
31
  do_sample=True,
32
  top_p=1,
33
  top_k=50,
34
+ temperature=float(temperature),
35
  num_beams=1,
36
  stopping_criteria=StoppingCriteriaList([stop])
37
  )
 
45
  yield partial_message
46
 
47
 
48
+ iface = gr.ChatInterface(
49
+ fn=predict,
50
+ inputs=["text", "text", gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
51
+ gr.Slider(minimum=1, maximum=2048, value=1024, label="Max Tokens"),
52
+ gr.Slider(minimum=1, maximum=1024, value=1, label="Min Tokens")],
53
+ outputs="text"
54
+ )
55
+
56
+ iface.queue().launch()