import torch
import gradio as gr
import os
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import openai
import time
import logging
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
import nltk
from langchain.docstore.document import Document
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Download all required NLTK data upfront
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')
# Initialize the OpenAI API key from the environment rather than hardcoding it;
# set OPENAI_API_KEY before launching the app.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Load the ragbench datasets
ragbench = {}
for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
logger.info(f"Loaded {dataset}")
# Initialize with a stronger sentence-transformer model, placed on GPU when available
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={'device': str(device)},
)
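# Optional sanity check: all-mpnet-base-v2 is expected to emit 768-dimensional
# vectors; logging the dimension once makes indexing failures easier to diagnose.
probe = embedding_model.embed_query("sanity check")
logger.info(f"Embedding dimension: {len(probe)}")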
# Chunking function: greedily pack sentences into chunks of at most max_chunk_size characters
def chunk_documents_semantic(documents, max_chunk_size=500):
    chunks = []

    def chunk_text(text):
        # Split into sentences, then accumulate them until the size budget is exceeded
        sentences = sent_tokenize(text)
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence + " "
            else:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + " "
        if current_chunk:
            chunks.append(current_chunk.strip())

    for doc in documents:
        # Some RagBench entries store a document as a list of passages
        if isinstance(doc, list):
            for passage in doc:
                chunk_text(passage)
        else:
            chunk_text(doc)
    return chunks
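# Illustrative usage (not executed): sentences are packed greedily until the
# character budget is hit, so two short sentences stay in a single chunk.
# >>> chunk_documents_semantic(["One short sentence. Another short sentence."])
# ['One short sentence. Another short sentence.']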
# Process documents and create vectordb
documents = []
for dataset_name in ragbench.keys():
for split in ragbench[dataset_name].keys():
original_documents = ragbench[dataset_name][split]['documents']
chunked_documents = chunk_documents_semantic(original_documents)
documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
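# Log the corpus size before indexing; embedding these chunks dominates startup time.
logger.info(f"Prepared {len(documents)} chunks across all datasets")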
# Initialize vectordb with processed documents
vectordb = Chroma.from_documents(
documents=documents,
embedding=embedding_model,
persist_directory='./docs/chroma/'
)
vectordb.persist()
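# Note: once persisted, the index could be reloaded on later runs instead of being
# rebuilt, e.g. Chroma(persist_directory='./docs/chroma/', embedding_function=embedding_model).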
def process_query(query, dataset_choice):
try:
logger.info(f"Processing query for {dataset_choice}: {query}")
relevant_docs = vectordb.max_marginal_relevance_search(
query,
k=5,
fetch_k=10
)
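        # MMR fetches fetch_k candidates, then keeps the k most diverse among them.
        # Note: retrieval runs over the combined index of all datasets; dataset_choice
        # only steers the prompt below, not the search itself.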
context = " ".join([doc.page_content for doc in relevant_docs])
response = openai.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
{"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
],
max_tokens=300,
temperature=0.7,
)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"Error processing query: {str(e)}")
return f"Error: {str(e)}"
# Create Gradio interface with dataset selection
demo = gr.Interface(
fn=process_query,
inputs=[
gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
gr.Dropdown(
choices=list(ragbench.keys()),
label="Select Dataset",
value="hotpotqa"
)
],
outputs=gr.Textbox(label="Answer", lines=5),
title="RagBench Question Answering System",
description="Ask questions across different RagBench datasets",
examples=[
["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
]
)
if __name__ == "__main__":
demo.launch(debug=True)