import torch
import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import openai
import logging
from datasets import load_dataset
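# NOTE: no OpenAI key is set in this file. The openai client falls back to the
# OPENAI_API_KEY environment variable, which this app assumes is provided
# (e.g. as a Space secret).
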
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the ragbench datasets
ragbench = {}
for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
    ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    logger.info(f"Loaded {dataset}")
# Initialize with a stronger model for better semantic understanding
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={"device": device})
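
# `process_query` below retrieves context from a `vectordb` store that this file
# never constructs. The block below is a minimal sketch of one way to build it:
# index the `documents` column of each loaded ragbench split into a single Chroma
# collection, tagging every passage with its source dataset. The field names and
# the per-dataset cap are assumptions; adjust them to match your copy of the data.
def build_vectordb(datasets, embeddings, max_examples_per_dataset=500):
    texts, metadatas = [], []
    for name, ds in datasets.items():
        split = ds["train"]
        n = min(max_examples_per_dataset, len(split))
        for example in split.select(range(n)):
            for doc in example["documents"]:
                texts.append(doc)
                metadatas.append({"dataset": name})
    # Chroma.from_texts embeds the passages with the HuggingFace model configured above
    return Chroma.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas)

vectordb = build_vectordb(ragbench, embedding_model)
logger.info("Vector store ready")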
def process_query(query, dataset_choice):
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")

        # Get relevant documents specific to the chosen dataset (the metadata
        # filter assumes the "dataset" tag attached when the store was built above)
        relevant_docs = vectordb.max_marginal_relevance_search(
            query,
            k=5,         # top 5 most relevant documents
            fetch_k=10,  # fetch top 10, then select the most diverse 5
            filter={"dataset": dataset_choice},
        )
        context = " ".join([doc.page_content for doc in relevant_docs])

        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()

    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"
# Create Gradio interface with dataset selection
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ]
)
if __name__ == "__main__":
    demo.launch(debug=True)