Hann99 commited on
Commit
47b6faf
·
1 Parent(s): b183d3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -12
app.py CHANGED
@@ -1,13 +1,46 @@
1
- from transformers import PegasusForConditionalGeneration, PegasusTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as grad
3
- mdl_name = "google/pegasus-xsum"
4
- pegasus_tkn = PegasusTokenizer.from_pretrained(mdl_name)
5
- mdl = PegasusForConditionalGeneration.from_pretrained(mdl_name)
6
- def summarize(text):
7
- tokens = pegasus_tkn(text, truncation=True, padding="longest", return_tensors="pt")
8
- txt_summary = mdl.generate(**tokens)
9
- response = pegasus_tkn.batch_decode(txt_summary, skip_special_tokens=True)
10
- return response
11
- txt=grad.Textbox(lines=10, label="English", placeholder="English Text here")
12
- out=grad.Textbox(lines=10, label="Summary")
13
- grad.Interface(summarize, inputs=txt, outputs=out).launch()
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
2
+ import torch
3
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
4
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
5
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
6
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
7
+ def converse(user_input, chat_history=[]):
8
+ user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
9
+ # keep history in the tensor
10
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
11
+ # get response
12
+ chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
13
+ print (chat_history)
14
+ response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
15
+ print("starting to print response")
16
+ print(response)
17
+ # html for display
18
+ html = "<div class='mybot'>"
19
+ for x, mesg in enumerate(response):
20
+ if x%2!=0 :
21
+ mesg="Alicia:"+mesg
22
+ clazz="alicia"
23
+ else :
24
+ clazz="user"
25
+ print("value of x")
26
+ print(x)
27
+ print("message")
28
+ print (mesg)
29
+ html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
30
+ html += "</div>"
31
+ print(html)
32
+ return html, chat_history
33
  import gradio as grad
34
+ css = """
35
+ .mychat {display:flex;flex-direction:column}
36
+ .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
37
+ .mesg.user {background-color:lightblue;color:white}
38
+ .mesg.alicia {background-color:orange;color:white,align-self:self-end}
39
+ .footer {display:none !important}
40
+ """
41
+ text=grad.inputs.Textbox(placeholder="Lets chat")
42
+ grad.Interface(fn=converse,
43
+ theme="default",
44
+ inputs=[text, "state"],
45
+ outputs=["html", "state"],
46
+ css=css).launch()