Spaces:
Build error
Build error
File size: 2,152 Bytes
71f2227 3e149d5 71f2227 0a042d8 cdfe192 703a8b7 0a042d8 61a2cb3 fbd4cfd 0d07fde fbd4cfd 870b1fe 482fbde fd58ec5 2dfc510 ed80cd3 3e149d5 55f3dd8 c669d92 f13ff81 691bbb0 c669d92 482fbde c669d92 f13ff81 08d35ca 703a8b7 db0e2dc cdfe192 6bbb073 337fe11 2ddaf14 6bbb073 2ddaf14 337fe11 cdfe192 db0e2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer
# Hugging Face Hub checkpoints used by the app: a conversational model for the
# chat replies and a T5 model used to suggest grammar corrections.
DIALOG_CHECKPOINT = "microsoft/DialoGPT-large"
GRAMMAR_CHECKPOINT = 'deep-learning-analytics/GrammarCorrector'

# Conversational model + its tokenizer (loaded once at import time).
tokenizer = AutoTokenizer.from_pretrained(DIALOG_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(DIALOG_CHECKPOINT)

# Grammar-correction model + its tokenizer (loaded once at import time).
grammar_tokenizer = T5Tokenizer.from_pretrained(GRAMMAR_CHECKPOINT)
grammar_model = T5ForConditionalGeneration.from_pretrained(GRAMMAR_CHECKPOINT)
import torch
import gradio as gr
def chat(message, history=None):
    """Generate a DialoGPT reply to *message* plus grammar feedback on it.

    Args:
        message: The user's latest utterance.
        history: Conversation state from the previous call: a list of
            ``(message, response, chat_history_ids)`` tuples, where
            ``chat_history_ids`` is the token-id tensor returned by
            ``model.generate`` for that turn. ``None`` or an empty list starts
            a new conversation. The caller's list is not modified.

    Returns:
        A 3-tuple ``(chatbot_pairs, history, feedback_text)``:
        ``chatbot_pairs`` is the list of ``(message, response)`` pairs for
        the Chatbot component, ``history`` is the updated state (including
        the token ids needed to continue the conversation), and
        ``feedback_text`` is the grammar suggestion returned by ``feedback``.
    """
    # A mutable default would be shared across calls (and across users).
    # Gradio state may also arrive as None. Start from a fresh copy.
    history = [] if history is None else list(history)

    new_user_input_ids = tokenizer.encode(
        message + tokenizer.eos_token, return_tensors="pt"
    )
    if history:
        # The previous generate() output already holds the whole conversation
        # so far (prompt + reply), so the new turn is appended to it.
        bot_input_ids = torch.cat([history[-1][2], new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # GPT-2-family models cannot attend past their position-embedding table
    # (config.n_positions). Longer inputs fail inside the model. Keep only
    # the most recent tokens and leave room for the reply. The oldest kept
    # turn may be cut mid-utterance, which is acceptable for chat context.
    max_positions = getattr(model.config, "n_positions", 1024)
    reply_budget = min(256, max_positions // 2)
    bot_input_ids = bot_input_ids[:, -(max_positions - reply_budget):]

    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=max_positions,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens produced after the prompt form the new reply.
    response_ids = chat_history_ids[0, bot_input_ids.shape[-1]:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)

    history.append((message, response, chat_history_ids))
    # gr.Chatbot displays (user, bot) string pairs. The token-id tensor is
    # state only and must not be sent to the UI component.
    chatbot_pairs = [(user, bot) for user, bot, _ in history]
    return chatbot_pairs, history, feedback(message)
def feedback(text, max_length=64, num_beams=2):
    """Return a short grammar suggestion for *text*.

    Runs the T5 grammar-correction model on *text* and compares its single
    best output with the input.

    Args:
        text: The sentence to check.
        max_length: Token budget for both the encoded input (longer input is
            truncated) and the generated correction.
        num_beams: Beam width used when generating the correction.

    Returns:
        ``'Looks good! Keep up the good work'`` when the correction matches
        *text* (ignoring whitespace differences). Otherwise, a message quoting
        the corrected sentence.
    """
    # NOTE(review): input longer than max_length tokens is truncated, so the
    # suggested "correction" may be a shortened sentence. Confirm that this
    # limit suits expected message lengths.
    batch = grammar_tokenizer(
        [text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    # No do_sample=True is passed, so under transformers' defaults this is a
    # beam search and a sampling temperature would have no effect (recent
    # versions warn about it). The temperature argument is therefore omitted.
    corrections = grammar_model.generate(
        **batch,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=1,
    )
    corrected_text = grammar_tokenizer.decode(
        corrections[0],
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    print("The corrections are: ", corrections)

    # Compare with whitespace normalized, so extra or trailing spaces in the
    # user's message are not reported as a grammar "correction".
    if " ".join(corrected_text.split()) == " ".join(text.split()):
        return 'Looks good! Keep up the good work'
    return f"'{corrected_text}' might be a little better"
# Web UI wiring. Inputs: the user's message and the hidden conversation
# state. Outputs: the conversation view, the updated state and the grammar
# feedback, in the same order as the three values ``chat`` returns.
message_box = gr.Textbox(label="Send messages here")
conversation_view = gr.Chatbot(color_map=("green", "gray"), label='Conversation')
feedback_box = gr.Textbox(label="Feedback", lines=1)

iface = gr.Interface(
    fn=chat,
    inputs=[message_box, "state"],
    outputs=[conversation_view, "state", feedback_box],
    allow_screenshot=False,
    allow_flagging="never",
)
iface.launch()
|