import os
import torch
import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import openai
import logging
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
import nltk
from langchain.docstore.document import Document
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Download all required NLTK data upfront
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')
# Read the OpenAI API key from the environment instead of hardcoding a secret in source
openai.api_key = os.environ["OPENAI_API_KEY"]
# Load the ragbench datasets
ragbench = {}
for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
    ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    logger.info(f"Loaded {dataset}")
# Initialize with a stronger embedding model (all-mpnet-base-v2, 768-dim sentence vectors)
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pass the device through model_kwargs rather than reaching into the private .client attribute
embedding_model = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': device})
# Chunking: greedily pack whole sentences into chunks of at most max_chunk_size characters
def chunk_documents_semantic(documents, max_chunk_size=500):
    def chunk_passage(passage, chunks):
        sentences = sent_tokenize(passage)
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence + " "
            else:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
        if current_chunk:
            chunks.append(current_chunk.strip())

    chunks = []
    for doc in documents:
        # Some ragbench rows hold a list of passages, others a single string; handle both
        passages = doc if isinstance(doc, list) else [doc]
        for passage in passages:
            chunk_passage(passage, chunks)
    return chunks
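# Quick illustration of the greedy packing (example values are ours, not from
# the original script):
#
#   chunk_documents_semantic(["First sentence. Second sentence."], max_chunk_size=20)
#   # -> ['First sentence.', 'Second sentence.']
#
# The first sentence fits in the empty 20-character budget; adding the second
# would overflow it, so the running chunk is flushed and each sentence ends up
# in its own chunk.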
# Flatten every split of every dataset into one list of chunked LangChain Documents
documents = []
for dataset_name in ragbench.keys():
    for split in ragbench[dataset_name].keys():
        original_documents = ragbench[dataset_name][split]['documents']
        chunked_documents = chunk_documents_semantic(original_documents)
        documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
# Initialize vectordb with processed documents
vectordb = Chroma.from_documents(
    documents=documents,
    embedding=embedding_model,
    persist_directory='./docs/chroma/'
)
vectordb.persist()
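# Sketch (our addition, not part of the original flow): because the store is
# persisted above, a later run could reload it instead of re-embedding every
# chunk from scratch.
#
#   vectordb = Chroma(
#       persist_directory='./docs/chroma/',
#       embedding_function=embedding_model,
#   )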
def process_query(query, dataset_choice):
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")
        # Max-marginal-relevance search: fetch 10 candidates, keep the 5 that
        # balance relevance to the query against diversity among themselves
        relevant_docs = vectordb.max_marginal_relevance_search(
            query,
            k=5,
            fetch_k=10
        )
        context = " ".join([doc.page_content for doc in relevant_docs])
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"
# Create Gradio interface with dataset selection
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ]
)
if __name__ == "__main__":
    demo.launch(debug=True)