Tahsin-Mayeesha's picture
Update
6ac74f2
raw
history blame
1.65 kB
import functools

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hub checkpoint served when the UI choice is not one of the keys below.
_DEFAULT_CHECKPOINT = "jannatul17/squad-bn-qgen-banglat5-v1"

# Maps the short model names shown in the UI to their Hub checkpoint ids.
_MODEL_CHECKPOINTS = {
    "mt5-small": "jannatul17/squad-bn-qgen-mt5-small-v1",
    "mt5-base": "Tahsin-Mayeesha/squad-bn-mt5-base2",
}


def choose_model(model_choice: str) -> str:
    """Return the checkpoint id for the given short model name.

    Args:
        model_choice: Short model name, e.g. ``"mt5-small"`` or
            ``"mt5-base"``.

    Returns:
        The matching checkpoint id. Any other value (including
        ``"banglat5"``) resolves to the BanglaT5 checkpoint.
    """
    return _MODEL_CHECKPOINTS.get(model_choice, _DEFAULT_CHECKPOINT)
# Bounded to the three checkpoints the app offers, so each one is loaded
# at most once per process instead of on every request.
@functools.lru_cache(maxsize=3)
def _load_model_and_tokenizer(model_name):
    """Load and memoize the seq2seq model and tokenizer for *model_name*.

    Args:
        model_name: Checkpoint id accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()
    return model, tokenizer


def generate__questions(model_choice, context, answer):
    """Generate a question for *answer* given *context*.

    The model input is the single string
    ``"answer: <answer> context: <context>"``. One sequence is produced
    with beam search (5 beams, at most 64 tokens).

    Args:
        model_choice: Short model name, resolved through ``choose_model``.
        context: Passage the question should be about.
        answer: Answer span the question should target; ``None`` is
            treated as an empty string.

    Returns:
        The decoded question with special tokens removed, any
        ``"question: "`` marker replaced by a space, and surrounding
        whitespace stripped.

    Raises:
        gr.Error: If *context* is empty or only whitespace.
    """
    # Fail with a user-visible message rather than generating from nothing.
    if not context or not context.strip():
        raise gr.Error("Please provide a context passage.")
    # Guard the string concatenation below against a missing answer.
    answer = answer or ""

    model, tokenizer = _load_model_and_tokenizer(choose_model(model_choice))

    text = "answer: " + answer + " context: " + context
    # Calling the tokenizer directly supersedes the legacy encode_plus API.
    text_encoding = tokenizer(text, return_tensors="pt")

    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"],
        attention_mask=text_encoding["attention_mask"],
        max_length=64,
        num_beams=5,
        num_return_sequences=1,
    )
    question = tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Replacing the marker with a space can leave leading whitespace, so
    # strip the result before returning it.
    return question.replace("question: ", " ").strip()
# Input widgets, built in the positional order generate__questions expects:
# (model_choice, context, answer).
model_selector = gr.Dropdown(
    label="Model",
    choices=["mt5-small", "mt5-base", "banglat5"],
    value="banglat5",
)
context_box = gr.Textbox(label="Context")
answer_box = gr.Textbox(label="Answer")

# Output widget that displays the generated question.
question_box = gr.Textbox(label="Question")

# Wire the widgets to the generation function and start the web UI.
demo = gr.Interface(
    fn=generate__questions,
    inputs=[model_selector, context_box, answer_box],
    outputs=question_box,
    title="Bangla Question Generation",
    description="Get the Question from given Context and an Answer",
)
demo.launch()