# Hugging Face Space (status: Paused)
import gradio as gr
import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizerFast
# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tokenizer built from the stock uncased BERT vocabulary.
# NOTE(review): assumes the QA checkpoint below uses this same vocabulary — confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Question-answering weights from the hub checkpoint, placed on the chosen device.
model = BertForQuestionAnswering.from_pretrained("CountingMstar/ai-tutor-bert-model")
model = model.to(device)
def get_prediction(context, question):
    """Extract the answer to *question* from *context* with the BERT QA model.

    The question and context are encoded as one sequence pair. The model
    scores every token as a possible answer start and end, and the
    highest-scoring span is decoded back to text. Only spans that lie
    entirely inside the context and whose end does not precede the start
    are considered.

    Args:
        context: Passage of text to search for the answer.
        question: Question to answer from *context*.

    Returns:
        The answer text, or an empty string when the encoded pair contains
        no context tokens (e.g. an empty context).
    """
    # Truncate only the context (the second sequence) so long passages stay
    # within the model's position-embedding limit instead of raising.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=model.config.max_position_embeddings,
        return_tensors="pt",
    ).to(device)

    # Pure inference: no gradients are needed, so skip the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # sequence_ids() labels question tokens 0, context tokens 1 and special
    # tokens None; only context tokens may form the answer.
    in_context = torch.tensor(
        [seq_id == 1 for seq_id in inputs.sequence_ids(0)],
        device=start_logits.device,
    )
    if not in_context.any():
        return ""

    # Score every (start, end) pair jointly and mask out pairs with
    # end < start or with either position outside the context. Independent
    # argmaxes could pick end < start (empty answer) or question tokens.
    scores = start_logits[:, None] + end_logits[None, :]
    positions = torch.arange(scores.size(0), device=scores.device)
    valid = (
        (positions[:, None] <= positions[None, :])
        & in_context[:, None]
        & in_context[None, :]
    )
    scores = scores.masked_fill(~valid, float("-inf"))

    # Recover (row, column) = (start, end) from the flat argmax index.
    answer_start, answer_end = divmod(int(torch.argmax(scores)), scores.size(1))
    answer_ids = inputs["input_ids"][0][answer_start:answer_end + 1]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)
def question_answer(context, question):
    """Return the answer to *question* drawn from *context*.

    Thin delegate to get_prediction; kept as a separate entry point.
    """
    return get_prediction(context, question)
def submit(context, question):
    """Gradio callback: answer the question typed into the UI.

    Args:
        context: Text from the context input box.
        question: Text from the question input box.

    Returns:
        The answer string produced by question_answer.
    """
    return question_answer(context, question)
# (context, question) pairs offered as clickable examples; each becomes a
# two-element list in the same order as submit's parameters.
_EXAMPLE_PAIRS = (
    ("A large language model is...", "What is a large language model?"),
    ("Feature engineering is the process of...", "What is Feature engineering?"),
    ("Attention mechanism calculates soft weights...", "What is Attention mechanism?"),
)
examples = [list(pair) for pair in _EXAMPLE_PAIRS]
# gr.Textbox's first positional argument is the initial *value*, so passing
# "Context"/"Question" positionally pre-filled the boxes with those words
# instead of labelling them. Pass label= explicitly.
input_textbox = gr.Textbox(label="Context", placeholder="Enter context here")
question_textbox = gr.Textbox(label="Question", placeholder="Enter question here")

# gr.Interface expects a list of input components, one per parameter of fn
# (here: context, question). gr.Row is a layout container whose arguments are
# keyword-only, so gr.Row([...]) cannot serve as the interface inputs.
input_section = [input_textbox, question_textbox]
# Markdown help text describing how to use the example questions.
# (The scraped copy had stray "| |" residue embedded in this literal; removed.)
markdown_text = """
## Example Questions
Use the examples below or enter your own context and question.
"""
# Two text inputs (context, question) -> one text output (the answer).
iface = gr.Interface(
    fn=submit,
    inputs=input_section,
    # The first positional argument of gr.Textbox is its value; use label=
    # so the output box is titled "Answer" rather than pre-filled with it.
    outputs=gr.Textbox(label="Answer"),
    examples=examples,
    # Keep live mode off so answers are produced via the Submit button.
    # live=True would hide that button and re-run BERT on every edit.
    live=False,
    title="BERT Question Answering",
    # Show the example-questions help text above the interface.
    description=markdown_text,
)
iface.launch()