jaeyoungk commited on
Commit
c3ac709
·
1 Parent(s): 43c7f9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -47
app.py CHANGED
@@ -1,49 +1,3 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
4
- from threading import Thread
5
 
6
- tokenizer = AutoTokenizer.from_pretrained("RAIJAY/7B_QA_68348")
7
- model = AutoModelForCausalLM.from_pretrained("RAIJAY/7B_QA_68348", torch_dtype=torch.float16)
8
- model = model.to('cuda:0')
9
-
10
- class StopOnTokens(StoppingCriteria):
11
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
12
- stop_ids = [29, 0]
13
- for stop_id in stop_ids:
14
- if input_ids[0][-1] == stop_id:
15
- return True
16
- return False
17
-
18
- def predict(message, history):
19
-
20
- history_transformer_format = history + [[message, ""]]
21
- stop = StopOnTokens()
22
-
23
- messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]]) #curr_system_message +
24
- for item in history_transformer_format])
25
-
26
- model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
27
- streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
28
- generate_kwargs = dict(
29
- model_inputs,
30
- streamer=streamer,
31
- max_new_tokens=1024,
32
- do_sample=True,
33
- top_p=0.95,
34
- top_k=1000,
35
- temperature=1.0,
36
- num_beams=1,
37
- stopping_criteria=StoppingCriteriaList([stop])
38
- )
39
- t = Thread(target=model.generate, kwargs=generate_kwargs)
40
- t.start()
41
-
42
- partial_message = ""
43
- for new_token in streamer:
44
- if new_token != '<':
45
- partial_message += new_token
46
- yield partial_message
47
-
48
-
49
- gr.ChatInterface(predict).queue().launch()
 
1
  import gradio as gr
 
 
 
2
 
3
+ gr.Interface(fn=predict, inputs="text", outputs="text").launch()