import torch
import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import openai  # the default client reads OPENAI_API_KEY from the environment
import logging
from datasets import load_dataset

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the ragbench datasets
ragbench = {}
for dataset in ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa']:
    ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset)
    logger.info(f"Loaded {dataset}")

# Initialize a stronger sentence-embedding model for better semantic understanding
model_name = 'sentence-transformers/all-mpnet-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={"device": device},  # load the encoder onto the GPU when available
)
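
# Build one Chroma collection per ragbench dataset so retrieval can be scoped
# to the dataset the user selects in the UI. This is a minimal sketch: it
# assumes each train split exposes a `documents` column holding a list of
# context passages per example, and caps the number of indexed examples
# (MAX_EXAMPLES_PER_DATASET is a hypothetical knob, not part of ragbench)
# to keep start-up time manageable for a demo.
MAX_EXAMPLES_PER_DATASET = 500

vectordbs = {}
for name, ds in ragbench.items():
    train_split = ds["train"]
    subset = train_split.select(range(min(MAX_EXAMPLES_PER_DATASET, len(train_split))))
    texts = []
    for example in subset:
        texts.extend(example["documents"])  # list of passage strings
    vectordbs[name] = Chroma.from_texts(
        texts=texts,
        embedding=embedding_model,
        collection_name=name,
    )
    logger.info(f"Indexed {len(texts)} passages for {name}")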

def process_query(query, dataset_choice):
    """Retrieve context from the chosen dataset's vector store and answer with GPT-4."""
    try:
        logger.info(f"Processing query for {dataset_choice}: {query}")

        # Retrieve documents from the vector store built for the chosen dataset,
        # using maximal marginal relevance to balance relevance and diversity
        relevant_docs = vectordbs[dataset_choice].max_marginal_relevance_search(
            query,
            k=5,        # return the 5 most relevant documents
            fetch_k=10  # fetch the top 10 candidates, then select a diverse 5
        )

        context = " ".join(doc.page_content for doc in relevant_docs)
        
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        
        return response.choices[0].message.content.strip()
        
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return f"Error: {str(e)}"

# Create Gradio interface with dataset selection
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
        gr.Dropdown(
            choices=list(ragbench.keys()),
            label="Select Dataset",
            value="hotpotqa"
        )
    ],
    outputs=gr.Textbox(label="Answer", lines=5),
    title="RagBench Question Answering System",
    description="Ask questions across different RagBench datasets",
    examples=[
        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
    ]
)

if __name__ == "__main__":
    demo.launch(debug=True)