import os
import logging

import torch
import gradio as gr
import nltk
import openai
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.docstore.document import Document
from nltk.tokenize import sent_tokenize

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download the NLTK tokenizer data required by sent_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')

# Read the OpenAI API key from the environment; never hardcode secrets
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Load the ragbench datasets
ragbench = {}
for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
    ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    logger.info(f"Loaded {dataset}")

# Initialize a stronger embedding model, on GPU when available
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={'device': device},
)

# Chunking function: split documents into sentence-aligned chunks of at
# most max_chunk_size characters
def chunk_documents_semantic(documents, max_chunk_size=500):
    chunks = []
    for doc in documents:
        # An entry may be a single passage or a list of passages
        passages = doc if isinstance(doc, list) else [doc]
        for passage in passages:
            current_chunk = ""
            for sentence in sent_tokenize(passage):
                if len(current_chunk) + len(sentence) <= max_chunk_size:
                    current_chunk += sentence + " "
                else:
                    if current_chunk:  # skip empty chunks from overlong sentences
                        chunks.append(current_chunk.strip())
                    current_chunk = sentence + " "
            if current_chunk:
                chunks.append(current_chunk.strip())
    return chunks
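
# Example: chunk_documents_semantic(["A. B."], max_chunk_size=3)
# returns ["A.", "B."]; each chunk ends on a sentence boundary.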

# Process documents and create the vector store, tagging each chunk with its
# source dataset so retrieval can honor the UI's dataset choice
documents = []
for dataset_name in ragbench.keys():
    for split in ragbench[dataset_name].keys():
        original_documents = ragbench[dataset_name][split]['documents']
        chunked_documents = chunk_documents_semantic(original_documents)
        documents.extend([
            Document(page_content=chunk, metadata={"dataset": dataset_name})
            for chunk in chunked_documents
        ])
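
# Note: embedding every chunk of every split below can take a long time on
# CPU; for a quick smoke test, restrict the loop above to a single dataset.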

# Initialize vectordb with processed documents
vectordb = Chroma.from_documents(
    documents=documents,
    embedding=embedding_model,
    persist_directory='./docs/chroma/'
)
# chromadb >= 0.4 persists automatically when persist_directory is set;
# the explicit persist() call is for older releases
vectordb.persist()

def process_query(query, dataset_choice):
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")
        
        # MMR retrieval: fetch 10 candidates, keep the 5 that best balance
        # query relevance with diversity, restricted to the chosen dataset
        relevant_docs = vectordb.max_marginal_relevance_search(
            query,
            k=5,
            fetch_k=10,
            filter={"dataset": dataset_choice},
        )
        
        context = " ".join([doc.page_content for doc in relevant_docs])
        
        # Module-level chat API; requires the openai >= 1.0 client
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        
        return response.choices[0].message.content.strip()
        
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"

# Create Gradio interface with dataset selection
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ]
)

if __name__ == "__main__":
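    # Tip: pass share=True to launch() to expose a temporary public URL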
    demo.launch(debug=True)