# Hugging Face Space by alaamostafa
# Last commit: 0b050b9 "Update app.py" (verified)
import gradio as gr
import torch
from transformers import pipeline
# Set up the baseline QA model.
# This model is pre-finetuned on SQuAD and is known to provide reasonable answers.
model_id = "distilbert-base-uncased-distilled-squad"

# Create a question-answering pipeline.
# Use the first CUDA GPU (device=0) when one is available and fall back to the
# CPU (device=-1) otherwise. Hard-coding device=0 makes the app crash at startup
# on CPU-only hardware.
device = 0 if torch.cuda.is_available() else -1
qa_pipeline = pipeline("question-answering", model=model_id, tokenizer=model_id, device=device)
def answer_question(question, context, max_seq_length, doc_stride, max_answer_length, n_best_size):
"""
This function uses the baseline QA pipeline to extract an answer from the provided context.
Hyperparameters:
- max_seq_length: Maximum tokens for the combined input.
- doc_stride: Overlap between segments when the context is too long.
- max_answer_length: Maximum tokens in the extracted answer.
- n_best_size: Number of candidate answer spans considered.
"""
# Run the pipeline with the provided hyperparameters.
result = qa_pipeline(
question=question,
context=context,
max_length=int(max_seq_length),
doc_stride=int(doc_stride),
max_answer_length=int(max_answer_length),
n_best_size=int(n_best_size)
)
return result.get("answer", "No answer found.")
# Gradio front end. The input components are handed to answer_question
# positionally, so their order must follow its parameter list:
# (question, context, max_seq_length, doc_stride, max_answer_length, n_best_size).
interface = gr.Interface(
    fn=answer_question,
    inputs=[
        # Free-text inputs.
        gr.Textbox(label="Question", placeholder="Enter your question here", lines=2),
        gr.Textbox(label="Context", placeholder="Enter the context here", lines=10),
        # Hyperparameter sliders (integer steps).
        gr.Slider(label="Max Seq Length", minimum=100, maximum=512, step=1, value=384),
        gr.Slider(label="Doc Stride", minimum=50, maximum=256, step=1, value=128),
        gr.Slider(label="Max Answer Length", minimum=10, maximum=100, step=1, value=30),
        gr.Slider(label="N Best Size", minimum=1, maximum=50, step=1, value=20),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Baseline QA Model Demo",
    description=(
        "This demo uses the baseline model (distilbert-base-uncased-distilled-squad) for question answering. "
        "Enter a question and a context, and adjust the hyperparameters as needed."
    ),
)

# Start the web UI only when this file is executed directly (e.g. as the
# app script of an HF Space), not when it is imported.
if __name__ == "__main__":
    interface.launch()