charanhu committed
Commit 9ad03b5 · 1 Parent(s): 5995eff

Update app.py

Files changed (1):
  app.py  +10 -18
app.py CHANGED
@@ -6,6 +6,7 @@ from threading import Thread
 # Load model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
 model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
+model = model.to('cuda:0')
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -15,7 +16,7 @@ class StopOnTokens(StoppingCriteria):
             return True
         return False
 
-def predict(message, history, temperature, max_new_tokens, min_new_tokens):
+def predict(message, history):
 
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
@@ -23,15 +24,18 @@ def predict(message, history, temperature, max_new_tokens, min_new_tokens):
     messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])  #curr_system_message +
                         for item in history_transformer_format])
 
-    model_inputs = tokenizer([messages], return_tensors="pt")
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
+    max_new_tokens = gr.Slider(minimum=0, maximum=2048, value=10, label="Temperature"),
+    min_new_tokens = gr.Slider(minimum=0, maximum=2048, value=1, label="Temperature"),
     generate_kwargs = dict(
         model_inputs,
-        max_new_tokens=int(max_new_tokens),
-        min_new_tokens=int(min_new_tokens),
+        max_new_tokens=int(max_new_tokens.value),
+        min_new_tokens=int(min_new_tokens.value),
         do_sample=True,
         top_p=1,
         top_k=50,
-        temperature=float(temperature),
+        temperature=float(temperature.value),
         num_beams=1,
         stopping_criteria=StoppingCriteriaList([stop])
     )
@@ -45,16 +49,4 @@ def predict(message, history, temperature, max_new_tokens, min_new_tokens):
         yield partial_message
 
 
-iface = gr.Interface(
-    fn=predict,
-    inputs=["text", "text", gr.Slider(minimum=0.1, maximum=2.0, default=1.0, label="Temperature"),
-            gr.Slider(minimum=1, maximum=2048, default=1024, label="Max Tokens"),
-            gr.Slider(minimum=1, maximum=1024, default=1, label="Min Tokens")],
-    outputs="text",
-    live=True,
-    capture_session=True,
-    layout="vertical",
-    chat=True
-)
-
-iface.launch()
+gr.ChatInterface(predict).queue().launch()
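
Review note on the committed predict: the three gr.Slider objects are constructed inside the function and never attached to any UI, so the .value reads can only return the hard-coded defaults; the trailing commas also bind each name to a one-element tuple rather than a Slider, and two of the labels repeat "Temperature". Below is a minimal sketch of how the knobs could instead reach predict, assuming a Gradio version where gr.ChatInterface supports additional_inputs; the stub body and slider defaults are illustrative, not part of this commit.

# Sketch only, not the committed code: wire the generation knobs through
# ChatInterface so predict receives real values (assumes gradio>=3.44,
# where additional_inputs values are passed to fn after message, history).
import gradio as gr

def predict(message, history, temperature, max_new_tokens, min_new_tokens):
    # Illustrative stub: the real app would call model.generate with these
    # values directly instead of reading .value off throwaway sliders.
    yield f"temperature={temperature}, max={max_new_tokens}, min={min_new_tokens}: {message}"

gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
        gr.Slider(minimum=1, maximum=2048, value=1024, label="Max new tokens"),
        gr.Slider(minimum=1, maximum=1024, value=1, label="Min new tokens"),
    ],
).queue().launch()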