import os
import torch
import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import openai
import time
import logging
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
import nltk
from langchain.docstore.document import Document
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Download all required NLTK data upfront
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')
# Initialize the OpenAI API key (read from the environment or Space secrets; never hardcode keys in source)
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Load the ragbench datasets
ragbench = {}
for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
    ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    logger.info(f"Loaded {dataset}")
# Initialize a stronger embedding model, on GPU when available
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={"device": device},
)
# Chunking function: split each document into chunks of whole sentences,
# capped at roughly max_chunk_size characters (a single longer sentence
# becomes its own chunk).
def chunk_documents_semantic(documents, max_chunk_size=500):
    chunks = []
    for doc in documents:
        # A document may be a plain string or a list of passages; normalize.
        passages = doc if isinstance(doc, list) else [doc]
        for passage in passages:
            current_chunk = ""
            for sentence in sent_tokenize(passage):
                if len(current_chunk) + len(sentence) <= max_chunk_size:
                    current_chunk += sentence + " "
                else:
                    if current_chunk:  # avoid emitting empty chunks
                        chunks.append(current_chunk.strip())
                    current_chunk = sentence + " "
            if current_chunk:
                chunks.append(current_chunk.strip())
    return chunks
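
# Example (sketch, hypothetical input): sentences are packed into chunks, and a
# sentence longer than max_chunk_size is kept whole as its own chunk:
#   sample = ["One. Two. Three.", ["Nested passage A.", "Nested passage B."]]
#   chunk_documents_semantic(sample, max_chunk_size=12)
#   -> ['One. Two.', 'Three.', 'Nested passage A.', 'Nested passage B.']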
# Process documents and create vectordb
documents = []
for dataset_name in ragbench.keys():
    for split in ragbench[dataset_name].keys():
        original_documents = ragbench[dataset_name][split]['documents']
        chunked_documents = chunk_documents_semantic(original_documents)
        documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
# Initialize vectordb with processed documents
vectordb = Chroma.from_documents(
    documents=documents,
    embedding=embedding_model,
    persist_directory='./docs/chroma/'
)
vectordb.persist()
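
# Note (sketch): since the store is persisted to ./docs/chroma/, a later run
# could reload it instead of re-embedding, assuming the same embedding model:
#   vectordb = Chroma(persist_directory='./docs/chroma/', embedding_function=embedding_model)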
def process_query(query, dataset_choice):
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")
        # Retrieve diverse, relevant chunks via maximal marginal relevance
        relevant_docs = vectordb.max_marginal_relevance_search(
            query,
            k=5,
            fetch_k=10
        )
        context = " ".join([doc.page_content for doc in relevant_docs])
        # Answer strictly from the retrieved context
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"
# Create Gradio interface with dataset selection
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ]
)
if __name__ == "__main__":
    demo.launch(debug=True)