# Source: tosin2013's Hugging Face Space — app.py (revision ba9ed60, 4.35 kB).
# The lines above were page chrome from the Spaces "raw"/"blame" view, not code.
import gradio as gr
from transformers import pipeline
import PyPDF2
import markdown
import matplotlib.pyplot as plt
import io
import base64
# Preload models
# Maps the short name shown in the UI dropdown to the Hugging Face model id
# that transformers.pipeline() actually downloads/loads.
models = {
    "distilbert-base-uncased-distilled-squad": "distilbert-base-uncased-distilled-squad",
    "roberta-base-squad2": "deepset/roberta-base-squad2",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "bert-large-uncased-whole-word-masking-finetuned-squad",
    "albert-base-v2": "twmkn9/albert-base-v2-squad2",
    "xlm-roberta-large-squad2": "deepset/xlm-roberta-large-squad2"
}
# Cache of instantiated QA pipelines, keyed by dropdown name.
# Filled lazily by load_model() so only selected models are loaded.
loaded_models = {}
def load_model(model_name):
    """Return a cached question-answering pipeline for *model_name*.

    The pipeline is built on first use and memoized in ``loaded_models``,
    so each model is downloaded/instantiated at most once per process.
    """
    try:
        return loaded_models[model_name]
    except KeyError:
        qa_pipeline = pipeline("question-answering", model=models[model_name])
        loaded_models[model_name] = qa_pipeline
        return qa_pipeline
def generate_score_chart(score):
    """Render the confidence score as a single-bar chart.

    Args:
        score: Model confidence in [0, 1].

    Returns:
        The chart as a base64-encoded PNG string (no data-URI prefix).
    """
    # Use the object-oriented API instead of the implicit pyplot "current
    # figure": Gradio may service several requests, and mutating pyplot's
    # global state (plt.figure / bare plt.close) is not safe in that setting.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(["Confidence Score"], [score], color='skyblue')
    ax.set_ylim(0, 1)
    ax.set_ylabel("Score")
    ax.set_title("Confidence Score")
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # close this figure specifically to free its memory
    buf.seek(0)
    return base64.b64encode(buf.getvalue()).decode()
def generate_report(answer, score, score_explanation, score_chart):
    """Build the Markdown report shown in the UI.

    Combines the answer, the formatted confidence score, its explanation,
    and the base64-encoded chart image into one Markdown document.
    """
    sections = (
        f"### Answer:\n\n{answer}",
        f"### Confidence Score: {score}",
        f"### Score Explanation:\n\n{score_explanation}",
        f"![Score Chart](data:image/png;base64,{score_chart})",
    )
    return "\n\n".join(sections)
def answer_question(model_name, file, question, status="Loading model..."):
    """Answer *question* from the uploaded document with the selected model.

    Args:
        model_name: Key into the module-level ``models`` dict.
        file: Uploaded file object from gr.File (has ``.name``/``.read()``),
            or None to run the question against an empty context.
        question: The user's question.
        status: Initial status text; kept for backward compatibility with
            existing callers — the original immediately overwrote it, so it
            now has a default and is otherwise unused.

    Returns:
        Tuple of (answer, formatted score, score explanation,
        base64 chart PNG, Markdown report, final status text).
    """
    model = load_model(model_name)

    # Extract the context text from the uploaded file, by extension.
    if file is not None:
        file_name = file.name
        if file_name.endswith(".pdf"):
            # extract_text() may return None (e.g. image-only pages);
            # guard with "or ''" and join once instead of += per page.
            context = "".join(
                page.extract_text() or "" for page in PyPDF2.PdfReader(file).pages
            )
        elif file_name.endswith(".md"):
            # Render Markdown to HTML before feeding it to the model,
            # matching the original behavior.
            context = markdown.markdown(file.read().decode('utf-8'))
        else:
            # Fallback: treat anything else as UTF-8 plain text.
            context = file.read().decode('utf-8')
    else:
        context = ""

    result = model(question=question, context=context)
    answer = result['answer']
    score = result['score']

    # Generate the score chart
    score_chart = generate_score_chart(score)
    # Explain score
    score_explanation = f"The confidence score ranges from 0 to 1, where a higher score indicates higher confidence in the answer's correctness. In this case, the score is {score:.2f}. A score closer to 1 implies the model is very confident about the answer."
    # Generate the report
    report = generate_report(answer, f"{score:.2f}", score_explanation, score_chart)
    status = "Model loaded"
    return answer, f"{score:.2f}", score_explanation, score_chart, report, status
# Define the Gradio interface: a model picker, a document + question input
# row, output widgets for every piece of answer_question's return tuple,
# and a submit button wired to the handler.
with gr.Blocks() as interface:
    gr.Markdown(
        """
# Question Answering System
Upload a document (text, PDF, or Markdown) and ask questions to get answers based on the context.
**Supported File Types**: `.txt`, `.pdf`, `.md`
""")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(models.keys()),
            label="Select Model",
            value="distilbert-base-uncased-distilled-squad",
        )
    with gr.Row():
        file_input = gr.File(label="Upload Document", file_types=["text", "pdf", "markdown"])
        question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question")
    with gr.Row():
        answer_output = gr.Textbox(label="Answer")
        score_output = gr.Textbox(label="Confidence Score")
        explanation_output = gr.Textbox(label="Score Explanation")
        chart_output = gr.Image(label="Score Chart")
        report_output = gr.Markdown(label="Report")
    with gr.Row():
        submit_button = gr.Button("Submit")
    status_output = gr.Markdown(value="")

    def on_submit(model_name, file, question):
        # Thin adapter that pins the status argument for answer_question.
        return answer_question(model_name, file, question, status="Loading model...")

    submit_button.click(
        on_submit,
        inputs=[model_dropdown, file_input, question_input],
        outputs=[answer_output, score_output, explanation_output, chart_output, report_output, status_output],
    )
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    # share=True also publishes a temporary public gradio.live URL.
    interface.launch(share=True)