Spaces:
Build error
Build error
File size: 2,864 Bytes
71f2227 3e149d5 71f2227 0a042d8 cdfe192 703a8b7 0a042d8 61a2cb3 fbd4cfd 0d07fde fbd4cfd 9bc6c2f 870b1fe 0d07fde fd58ec5 2dfc510 ed80cd3 3e149d5 55f3dd8 c669d92 704dec0 7617e7e 10ac208 08d35ca 703a8b7 db0e2dc cdfe192 247e692 db0e2dc cdfe192 db0e2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer
# Conversational model: DialoGPT produces the chat replies used by `chat`.
DIALOG_CHECKPOINT = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(DIALOG_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(DIALOG_CHECKPOINT)

# Grammar model: a T5 checkpoint whose output `feedback` offers as a suggestion.
GRAMMAR_CHECKPOINT = 'deep-learning-analytics/GrammarCorrector'
grammar_tokenizer = T5Tokenizer.from_pretrained(GRAMMAR_CHECKPOINT)
grammar_model = T5ForConditionalGeneration.from_pretrained(GRAMMAR_CHECKPOINT)
import torch
import gradio as gr
# def chat(message, history):
# history = history if history is not None else []
# new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
# history = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
# # response = tokenizer.decode(history[0]).replace("<|endoftext|>", "\n")
# # pretty print last output tokens from bot
# response = tokenizer.decode(bot_input_ids.shape[-1][0], skip_special_tokens=True)
# print("The response is ", [response])
# # history.append((message, response, new_user_input_ids, chat_history_ids))
# return response, history, feedback(message)
def chat(message, history=None):
    """Generate a DialoGPT reply to ``message`` plus grammar feedback on it.

    Gradio calls this with the user's text and the per-session ``state``.

    Args:
        message: The user's latest utterance.
        history: Conversation state from earlier turns: a list of
            ``(user_message, bot_response, chat_history_ids)`` tuples, where
            ``chat_history_ids`` is the token-id tensor ``model.generate``
            returned on that turn. ``None`` (or omitted) starts a new
            conversation.

    Returns:
        A 3-tuple ``(chatbot_pairs, new_history, feedback_text)``:
        ``chatbot_pairs`` is a list of ``(user_message, bot_response)`` string
        pairs for the "chatbot" output, ``new_history`` is the updated state
        (same tuple layout as ``history``), and ``feedback_text`` is the
        grammar suggestion produced by ``feedback(message)``.
    """
    # Fresh list per call: a mutable default would be shared by every session
    # that starts without state. Copying also leaves the caller's list intact.
    history = [] if history is None else list(history)

    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
    if history:
        # Continue from the token sequence produced on the previous turn,
        # extended with the new user input.
        last_set_of_ids = history[-1][2]
        bot_input_ids = torch.cat([last_set_of_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
    # Drop the prompt prefix so only the newly generated reply is decoded.
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)

    history.append((message, response, chat_history_ids))
    # The "chatbot" output renders (user, bot) string pairs; the token-id
    # tensors are kept only in the state output.
    chatbot_pairs = [(user_msg, bot_msg) for user_msg, bot_msg, _ in history]
    return chatbot_pairs, history, feedback(message)
def feedback(text):
    """Return a short grammar suggestion for ``text``.

    Runs the T5 grammar-correction model on ``text`` and compares its top
    correction with the input.

    Args:
        text: The user's message.

    Returns:
        ``'Looks good! Keep up the good work'`` when the model's correction is
        identical to ``text`` (ignoring leading/trailing whitespace);
        otherwise ``"'<correction>' might be a little better"``.
    """
    batch = grammar_tokenizer([text], truncation=True, padding='max_length', max_length=64, return_tensors="pt")
    # Deterministic beam search. A temperature setting only matters when
    # sampling, so none is passed.
    corrections = grammar_model.generate(**batch, max_length=64, num_beams=2, num_return_sequences=1)
    # Decode the whole sequence at once. Decoding a 1-D id tensor with
    # batch_decode yields one string per token, and joining those pieces
    # inserts spaces inside words.
    suggestion = grammar_tokenizer.decode(corrections[0], skip_special_tokens=True).strip()
    # generate() always returns a sequence, so "no correction needed" means
    # the model reproduced the input.
    if suggestion == text.strip():
        return 'Looks good! Keep up the good work'
    return f"'{suggestion}' might be a little better"
# Text box + session state in; chat transcript, updated state and grammar
# feedback text out (matching the three values `chat` returns).
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state", "text"],
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()
|