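"""Gradio demo: a DialoGPT-large chatbot that also runs each user message
through a T5 grammar-correction model and returns writing feedback."""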
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Conversational model used to generate the chat replies.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")

# Seq2seq model used to grammar-check the user's message.
grammar_tokenizer = T5Tokenizer.from_pretrained("deep-learning-analytics/GrammarCorrector")
grammar_model = T5ForConditionalGeneration.from_pretrained("deep-learning-analytics/GrammarCorrector")



def chat(message, history):
    history = history if history is not None else []
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
    if len(history) > 0:
        # Resume from the token ids stored with the last exchange.
        last_set_of_ids = history[-1][2]
        bot_input_ids = torch.cat([last_set_of_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    chat_history_ids = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id)
    # Everything past the prompt is the bot's reply; keep it 2-D so it can be
    # concatenated back onto the running context.
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
    bot_input_ids = torch.cat([bot_input_ids, response_ids], dim=-1)
    history.append((message, response, bot_input_ids))
    # The chatbot component expects (user, bot) pairs; the text output gets
    # the grammar feedback for the user's message.
    return [(m, r) for m, r, _ in history], history, feedback(message)


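# Grammar feedback: run the message through the T5 corrector and compare the
# correction with the original text. If they match (ignoring case and
# surrounding whitespace), the message is assumed to be fine as written.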
def feedback(text):
    batch = grammar_tokenizer([text], truncation=True, padding="max_length", max_length=64, return_tensors="pt")
    corrections = grammar_model.generate(**batch, max_length=64, num_beams=2, num_return_sequences=1)
    suggestion = grammar_tokenizer.decode(corrections[0], skip_special_tokens=True)
    if suggestion.strip().lower() == text.strip().lower():
        return "Looks good! Keep up the good work"
    return f"'{suggestion}' might be a little better"

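# Wire everything up: inputs are the user's message plus the hidden state;
# outputs are the chat transcript, the updated state, and the feedback text.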
iface = gr.Interface(
    chat,
    ["text", "state"],
    ["chatbot", "state", "text"],
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()
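# launch() serves the demo on a local URL (http://127.0.0.1:7860 by default);
# pass share=True to get a temporary public link.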