# Hugging Face Spaces status: Sleeping
import torch | |
import gradio as gr | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import Chroma | |
import openai | |
import time | |
import logging | |
from datasets import load_dataset | |
from nltk.tokenize import sent_tokenize | |
import nltk | |
# Configure INFO-level logging and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# RagBench subsets pulled from the "rungalileo/ragbench" dataset repository.
RAGBENCH_SUBSETS = (
    'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
    'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa',
)

# Load every subset eagerly at import time. The dict keeps insertion order,
# which is the order the dataset dropdown later lists them in.
ragbench = {}
for subset_name in RAGBENCH_SUBSETS:
    ragbench[subset_name] = load_dataset("rungalileo/ragbench", subset_name)
    logger.info(f"Loaded {subset_name}")
# Sentence-embedding model, chosen as a stronger model for semantic matching.
# The GPU is used when one is available, otherwise the CPU.
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
embedding_model = HuggingFaceEmbeddings(model_name=model_name)
# NOTE(review): `.client` is presumably the wrapped sentence-transformers
# model; newer LangChain releases prefer model_kwargs={"device": ...} — confirm.
embedding_model.client.to(device)
# Per-dataset Chroma stores, keyed by RagBench subset name. They are built
# lazily on first use so startup does not have to embed every subset.
_vectordb_cache = {}


def _get_vectordb(dataset_choice):
    """Return the Chroma store for one RagBench subset, building it on first use.

    The store indexes each distinct, non-empty document string from the
    subset's ``documents`` column across all of its splits.

    Raises:
        KeyError: if ``dataset_choice`` is not a key of ``ragbench``.
        ValueError: if the subset yields no documents to index.
    """
    vectordb = _vectordb_cache.get(dataset_choice)
    if vectordb is not None:
        return vectordb

    texts = []
    seen = set()
    # NOTE(review): assumes each RagBench row has a "documents" column holding
    # a list of passage strings — confirm against the dataset card.
    for split in ragbench[dataset_choice].values():
        for row_docs in split["documents"]:
            for text in row_docs:
                if text and text not in seen:
                    seen.add(text)
                    texts.append(text)
    if not texts:
        raise ValueError(f"No documents found for dataset {dataset_choice!r}")

    logger.info("Indexing %d documents for %s", len(texts), dataset_choice)
    # A distinct collection name per subset keeps the stores separate even if
    # the underlying Chroma client is shared.
    vectordb = Chroma.from_texts(
        texts,
        embedding_model,
        collection_name=f"ragbench_{dataset_choice}",
    )
    _vectordb_cache[dataset_choice] = vectordb
    return vectordb


def process_query(query, dataset_choice):
    """Answer ``query`` with GPT-4 using context retrieved from one RagBench subset.

    Args:
        query: The user's question.
        dataset_choice: Name of the RagBench subset (a key of ``ragbench``)
            whose documents are searched for context.

    Returns:
        The model's answer text. If the input is unusable or any step fails,
        a human-readable message is returned instead; errors are logged, not
        raised, so the UI always receives a string.
    """
    if not query or not query.strip():
        return "Please enter a question."
    if dataset_choice not in ragbench:
        return f"Error: unknown dataset {dataset_choice!r}"
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")
        # Get relevant documents from the chosen dataset's own store only.
        relevant_docs = _get_vectordb(dataset_choice).max_marginal_relevance_search(
            query,
            k=5,  # Top 5 most relevant documents
            fetch_k=10,  # Fetch top 10 then select most diverse 5
        )
        context = " ".join(doc.page_content for doc in relevant_docs)
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        # `content` can be None (e.g. a filtered completion); avoid calling .strip() on None.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs, show a message to the user.
        logger.exception(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"
# Gradio UI: a free-text question plus a dropdown over the loaded RagBench
# subsets, answered by process_query.
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ],
    # Do not pre-compute example outputs at startup. Each example would run
    # process_query (retrieval plus an OpenAI GPT-4 request), and whatever it
    # returned — including an "Error: ..." string — would be served for every
    # later click on that example.
    cache_examples=False,
)
def _main():
    """Launch the Gradio app defined in ``demo`` with debug mode enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _main()