|
import gradio as gr |
|
from transformers import pipeline |
|
import PyPDF2 |
|
import markdown |
|
import matplotlib.pyplot as plt |
|
import io |
|
import base64 |
|
import torch |
|
from fpdf import FPDF |
|
import os |
|
import tempfile |
|
import glob |
|
|
|
|
|
# Name offered in the model dropdown -> model identifier passed to
# ``transformers.pipeline`` when the model is first requested.
models = {
    "distilbert-base-uncased-distilled-squad": "distilbert-base-uncased-distilled-squad",
    "roberta-base-squad2": "deepset/roberta-base-squad2",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "bert-large-uncased-whole-word-masking-finetuned-squad",
    "albert-base-v2": "twmkn9/albert-base-v2-squad2",
    "xlm-roberta-large-squad2": "deepset/xlm-roberta-large-squad2",
}

# Pipelines that have already been built, keyed by the same names as ``models``.
loaded_models = {}

# Torch device for this process: CUDA when a GPU is visible, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
def load_model(model_name):
    """Return the question-answering pipeline for *model_name*, building it once.

    Pipelines are cached in the module-level ``loaded_models`` dict so that a
    model is only constructed the first time it is requested.

    Args:
        model_name: A key of the module-level ``models`` mapping.

    Returns:
        The cached (or freshly created) ``question-answering`` pipeline.

    Raises:
        KeyError: If *model_name* is not a key of ``models``.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    # Device index handed to the pipeline: 0 when CUDA is available, else -1.
    device_index = 0 if torch.cuda.is_available() else -1
    qa_pipeline = pipeline(
        "question-answering",
        model=models[model_name],
        device=device_index,
    )
    loaded_models[model_name] = qa_pipeline
    return qa_pipeline
|
|
|
def generate_score_chart(score):
    """Render *score* as a single-bar chart and return it as base64 PNG text.

    The chart is drawn on an explicit Figure/Axes pair instead of pyplot's
    implicit "current figure", and the figure is always closed, even when
    rendering fails, so repeated calls do not accumulate open figures.

    Args:
        score: Numeric confidence value; the y-axis is fixed to the 0..1 range.

    Returns:
        str: The PNG image bytes, base64-encoded and decoded to text.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.bar(["Confidence Score"], [score], color="skyblue")
        ax.set_ylim(0, 1)
        ax.set_ylabel("Score")
        ax.set_title("Confidence Score")

        png_buffer = io.BytesIO()
        fig.savefig(png_buffer, format="png")
    finally:
        # Release the figure whether or not drawing/saving succeeded.
        plt.close(fig)

    # getvalue() returns the whole buffer regardless of the stream position.
    return base64.b64encode(png_buffer.getvalue()).decode()
|
|
|
def highlight_relevant_text(context, start, end):
    """Wrap ``context[start:end]`` in a yellow ``<mark>`` element.

    Offsets are clamped to ``0 <= start <= end <= len(context)`` so that
    negative or inverted offsets cannot duplicate or drop parts of the text
    (plain slicing would interpret a negative offset as "from the end").

    NOTE: *context* is inserted verbatim; it is not HTML-escaped here.

    Args:
        context: The full text.
        start: Character offset where the highlighted span begins.
        end: Character offset just past the highlighted span.

    Returns:
        str: *context* with the span wrapped in ``<mark>...</mark>``.
    """
    length = len(context)
    start = min(max(start, 0), length)
    end = min(max(end, start), length)

    return "".join(
        (
            context[:start],
            '<mark style="background-color: yellow;">',
            context[start:end],
            "</mark>",
            context[end:],
        )
    )
|
|
|
def find_system_font(font_dirs=None, pattern="NotoSans*.ttf"):
    """Locate a TrueType font file by searching font directories recursively.

    Directories are searched in order; the first directory containing any
    match wins. Within that directory the choice is deterministic: files whose
    name ends in ``-Regular.ttf`` are preferred, then shorter file names (so
    ``NotoSans-Regular.ttf`` beats script-specific variants), then path order.
    ``glob`` gives no ordering guarantee, so simply taking the first hit could
    pick a different (e.g. bold or italic) face from run to run.

    Args:
        font_dirs: Iterable of directories to search. Defaults to
            ``/usr/share/fonts`` and ``/usr/local/share/fonts``.
        pattern: Glob pattern for the font file name.

    Returns:
        str: Path of the selected font file.

    Raises:
        FileNotFoundError: If no directory contains a matching file.
    """
    if font_dirs is None:
        font_dirs = ("/usr/share/fonts", "/usr/local/share/fonts")

    def _rank(path):
        name = os.path.basename(path)
        return (not name.lower().endswith("-regular.ttf"), len(name), path)

    for font_dir in font_dirs:
        candidates = glob.glob(os.path.join(font_dir, "**", pattern), recursive=True)
        if candidates:
            return min(candidates, key=_rank)

    raise FileNotFoundError("No suitable TTF font file found in system font directories.")
|
|
|
def generate_pdf_report(question, answer, score, score_explanation, score_chart, highlighted_context):
    """Build a PDF report for one question/answer pair and return it in memory.

    The report lists the question, answer, confidence score and explanation,
    then the highlighted context in a smaller font, followed by the score
    chart image.

    Args:
        question: The question text.
        answer: The answer text.
        score: Confidence score, already formatted for display.
        score_explanation: Explanatory sentence(s) about the score.
        score_chart: Base64-encoded PNG of the score chart.
        highlighted_context: Context text (written to the PDF as-is).

    Returns:
        io.BytesIO: The PDF document, positioned at offset 0.

    Raises:
        FileNotFoundError: Propagated from ``find_system_font`` when no font
            file is available.
    """
    # Decode first so a malformed payload fails before any file is created.
    chart_png = base64.b64decode(score_chart)

    pdf = FPDF()
    pdf.add_page()

    font_path = find_system_font()
    pdf.add_font("NotoSans", "", font_path)
    pdf.set_font("NotoSans", size=12)

    sections = (
        f"Question: {question}",
        f"Answer: {answer}",
        f"Confidence Score: {score}",
        f"Score Explanation: {score_explanation}",
        "Highlighted Context:",
    )
    for text in sections:
        pdf.multi_cell(0, 10, text)
        pdf.ln()

    pdf.set_font("NotoSans", size=10)
    pdf.multi_cell(0, 10, highlighted_context)
    pdf.ln()

    # The chart is handed to FPDF through a temporary PNG on disk. The file is
    # created with delete=False so it can be reopened by path after closing,
    # and is removed in ``finally`` so it does not leak if rendering fails.
    tmpfile = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        with tmpfile:
            tmpfile.write(chart_png)
        pdf.image(tmpfile.name, x=10, y=pdf.get_y(), w=100)

        pdf_output = io.BytesIO()
        pdf.output(pdf_output)
    finally:
        os.remove(tmpfile.name)

    pdf_output.seek(0)
    return pdf_output
|
|
|
def _read_context(file):
    """Return the text content of an uploaded document ("" when *file* is None)."""
    if file is None:
        return ""

    # The upload may arrive as an object exposing its path via ``.name`` or as
    # a plain path string (assumption: this depends on the Gradio version in
    # use - confirm against the installed release). Reading by path works for
    # both and does not depend on the object's current read position.
    path = getattr(file, "name", file)
    extension = os.path.splitext(path)[1].lower()

    if extension == ".pdf":
        reader = PyPDF2.PdfReader(path)
        # Treat a falsy extract_text() result (None or "") as empty text.
        return "".join(page.extract_text() or "" for page in reader.pages)

    with open(path, "rb") as handle:
        text = handle.read().decode("utf-8")
    if extension == ".md":
        # Markdown documents are converted to HTML before being queried.
        text = markdown.markdown(text)
    return text


def answer_question(model_name, file, question, status):
    """Answer *question* from the uploaded document and assemble all outputs.

    The document and question are validated before the model is loaded, so an
    empty question or a document without readable text produces a status
    message instead of an exception from inside the pipeline.

    Args:
        model_name: Key of the module-level ``models`` mapping.
        file: Uploaded document (object with a ``.name`` path, a path string,
            or None). ``.pdf`` and ``.md`` are handled specially (matched
            case-insensitively); anything else is read as UTF-8 text.
        question: The question to answer.
        status: Unused; kept so existing callers keep working.

    Returns:
        tuple: ``(highlighted_context, score, score_explanation, score_chart,
        pdf_report, status)`` where ``score`` is formatted to two decimals,
        ``score_chart`` is a base64-encoded PNG string and ``pdf_report`` is an
        ``io.BytesIO``. When there is nothing to answer, the first three items
        are ``""``, the chart and report are ``None`` and ``status`` explains
        why.
    """
    if not question or not question.strip():
        return "", "", "", None, None, "Please enter a question."

    context = _read_context(file)
    if not context.strip():
        return (
            "",
            "",
            "",
            None,
            None,
            "No text could be read from the document. Please upload a .txt, .pdf or .md file.",
        )

    model = load_model(model_name)
    result = model(question=question, context=context)
    score = result["score"]
    score_text = f"{score:.2f}"

    highlighted_context = highlight_relevant_text(context, result["start"], result["end"])
    score_chart = generate_score_chart(score)
    score_explanation = (
        "The confidence score ranges from 0 to 1, where a higher score indicates "
        "higher confidence in the answer's correctness. "
        f"In this case, the score is {score_text}. "
        "A score closer to 1 implies the model is very confident about the answer."
    )
    pdf_report = generate_pdf_report(
        question,
        result["answer"],
        score_text,
        score_explanation,
        score_chart,
        highlighted_context,
    )

    return highlighted_context, score_text, score_explanation, score_chart, pdf_report, "Model loaded"
|
|
|
|
|
def _write_temp_file(data, suffix):
    """Write *data* (bytes) to a new temporary file and return its path.

    The file is intentionally not deleted here: its path is returned to the
    UI component, which needs the file to exist after this call returns.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        handle.write(data)
    return handle.name


# Build the UI: model selector, document/question inputs, result widgets and
# the submit handler that wires them to ``answer_question``.
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # Question Answering System
        Upload a document (text, PDF, or Markdown) and ask questions to get answers based on the context.

        **Supported File Types**: `.txt`, `.pdf`, `.md`
        """)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(models.keys()),
            label="Select Model",
            value="distilbert-base-uncased-distilled-squad"
        )

    with gr.Row():
        # Entries are either a Gradio category ("text") or a dotted extension;
        # bare "pdf"/"markdown" are neither, so extensions are used instead.
        file_input = gr.File(label="Upload Document", file_types=["text", ".pdf", ".md"])
        question_input = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question")

    with gr.Row():
        answer_output = gr.HTML(label="Highlighted Answer")
        score_output = gr.Textbox(label="Confidence Score")
        explanation_output = gr.Textbox(label="Score Explanation")
        chart_output = gr.Image(label="Score Chart")
        pdf_output = gr.File(label="Download PDF Report")

    with gr.Row():
        submit_button = gr.Button("Submit")

    status_output = gr.Markdown(value="")

    def on_submit(model_name, file, question):
        """Run ``answer_question`` and adapt its results to the output widgets.

        ``answer_question`` returns the chart as a base64 string and the PDF
        as an in-memory stream, whereas the Image and File components are
        given file paths here; both payloads are written to temporary files.
        Missing payloads (None/empty) are passed through as None.
        """
        highlighted, score, explanation, chart_b64, pdf_stream, status = answer_question(
            model_name, file, question, status="Loading model..."
        )

        chart_path = None
        if chart_b64:
            chart_path = _write_temp_file(base64.b64decode(chart_b64), ".png")

        pdf_path = None
        if pdf_stream is not None:
            pdf_path = _write_temp_file(pdf_stream.getvalue(), ".pdf")

        return highlighted, score, explanation, chart_path, pdf_path, status

    submit_button.click(
        on_submit,
        inputs=[model_dropdown, file_input, question_input],
        outputs=[answer_output, score_output, explanation_output, chart_output, pdf_output, status_output]
    )


if __name__ == "__main__":
    # share=True asks Gradio to create a public share link for the app.
    interface.launch(share=True)
|
|