import logging
import os
import time

import gradio as gr
import GPUtil
import openai
import psutil
import torch
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the OpenAI API key from the environment (never hardcode secrets)
openai.api_key = os.environ["OPENAI_API_KEY"]

# Initialize the E5 embedding model (loaded for embedding-based retrieval;
# process_query below currently uses simple keyword matching instead)
model_name = 'intfloat/e5-base-v2'
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={"device": device})
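
# A minimal sketch (not part of the original app) of how the loaded E5 model
# could replace the keyword filter in process_query. E5 models expect
# "query: " / "passage: " prefixes on their inputs; the helper name and its
# use of numpy are assumptions for illustration only.
def e5_similarity(query, passages):
    import numpy as np  # local import: only needed by this optional helper
    q = np.array(embedding_model.embed_query(f"query: {query}"))
    p = np.array(embedding_model.embed_documents([f"passage: {d}" for d in passages]))
    # Cosine similarity between the query vector and each passage vector
    return (p @ q) / (np.linalg.norm(p, axis=1) * np.linalg.norm(q) + 1e-8)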

# Load datasets
datasets = {}
dataset_names = ['covidqa', 'hotpotqa', 'pubmedqa']  # Starting with key datasets

for name in dataset_names:
    datasets[name] = load_dataset("rungalileo/ragbench", name, split='train')
    logger.info(f"Loaded {name}")
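
# Note: load_dataset downloads each RagBench config on first run and caches it
# locally, so startup is slow only the first time.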

def get_system_metrics():
    """Snapshot CPU, memory, and (when available) GPU utilization."""
    gpus = GPUtil.getGPUs() if torch.cuda.is_available() else []
    metrics = {
        'cpu_percent': psutil.cpu_percent(),
        'memory_percent': psutil.virtual_memory().percent,
        'gpu_util': gpus[0].load * 100 if gpus else 0,
        'gpu_memory': gpus[0].memoryUtil * 100 if gpus else 0,
    }
    return metrics

def process_query(query, dataset_choice="all"):
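    """Answer a query with GPT-3.5, grounded on keyword-matched RagBench passages.

    Returns (answer_text, metrics_display) for the two Gradio output boxes.
    """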
    start_time = time.time()
    try:
        relevant_contexts = []
        search_datasets = [dataset_choice] if dataset_choice != "all" else list(datasets.keys())
        query_terms = [term.lower() for term in query.split()]

        for dataset_name in search_datasets:
            if dataset_name in datasets:
                # Each row's 'documents' field in RagBench holds a list of passages
                for row_docs in datasets[dataset_name]['documents']:
                    passages = row_docs if isinstance(row_docs, list) else [row_docs]
                    for doc in passages:
                        if any(term in doc.lower() for term in query_terms):
                            relevant_contexts.append((doc, dataset_name))
        
        # Keep the prompt short: ground on the first matching passage only
        context_info = (f"From {relevant_contexts[0][1]}: {relevant_contexts[0][0]}"
                        if relevant_contexts else
                        "No matching context found; answer from general knowledge.")
        
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a knowledgeable expert using E5 embeddings for precise information retrieval."},
                {"role": "user", "content": f"Context: {context_info}\nQuestion: {query}"}
            ],
            max_tokens=300,
            temperature=0.7,
        )
        
        # Get performance metrics
        metrics = get_system_metrics()
        metrics['processing_time'] = time.time() - start_time
        
        metrics_display = (
            f"Processing Time: {metrics['processing_time']:.2f}s\n"
            f"CPU Usage: {metrics['cpu_percent']}%\n"
            f"Memory Usage: {metrics['memory_percent']}%\n"
            f"GPU Utilization: {metrics['gpu_util']:.1f}%\n"
            f"GPU Memory: {metrics['gpu_memory']:.1f}%"
        )
        
        return response.choices[0].message.content.strip(), metrics_display
        
    except Exception as e:
        logger.exception("Query processing failed")
        return f"Error: {e}", "Metrics unavailable"

# Create Gradio interface
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(label="Question", placeholder="Ask your question here"),
        gr.Dropdown(
            choices=["all"] + dataset_names,
            label="Select Dataset",
            value="all"
        )
    ],
    outputs=[
        gr.Textbox(label="Response"),
        gr.Textbox(label="Performance Metrics")
    ],
    title="E5-Powered Knowledge Base",
    description="Search across RagBench datasets with performance monitoring"
)

if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)
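
# Example (assumed) invocation; the filename is hypothetical:
#   export OPENAI_API_KEY=...   # the app reads the key from the environment
#   python app.py               # then open the local Gradio URL it prints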