# Gradio / app.py
# Author: ajalisatgi
# Last commit: "Update app.py" (2e77d5f, verified)
import logging
import os
import time

import GPUtil
import gradio as gr
import openai
import psutil
import torch
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings
# Send records at INFO and above to the root handler, and give this module
# its own named logger for use below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Configure the OpenAI API key from the environment rather than source code.
# SECURITY: a secret key was previously hard-coded here. It remains in the
# repository history and must be revoked/rotated in the OpenAI dashboard.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    # Don't fail at import time (the UI can still start); requests to OpenAI
    # will error until the variable is set (e.g. as a Space secret).
    logger.warning("OPENAI_API_KEY is not set; OpenAI requests will fail.")
# Load the E5 sentence-embedding model, on the GPU when CUDA is available.
# NOTE(review): embedding_model is not referenced in the visible code;
# process_query retrieves context by keyword matching, not by embeddings.
model_name = 'intfloat/e5-base-v2'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hand the target device to the model constructor via the documented
# model_kwargs instead of moving the wrapper's internal `client` afterwards.
embedding_model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs={"device": str(device)},
)
# Load the train split of each chosen RAGBench subset into memory at start-up,
# keyed by subset name; process_query searches these.
dataset_names = ['covidqa', 'hotpotqa', 'pubmedqa']  # Starting with key datasets
datasets = {}
for name in dataset_names:
    datasets[name] = load_dataset(path="rungalileo/ragbench", name=name, split="train")
    logger.info("Loaded %s", name)
def get_system_metrics():
    """Return a snapshot of host CPU/memory usage and first-GPU utilisation.

    Returns:
        dict with keys 'cpu_percent', 'memory_percent', 'gpu_util' and
        'gpu_memory' (all percentages). The GPU values stay 0 when torch
        reports no CUDA device or GPUtil lists no GPUs.
    """
    metrics = {
        'cpu_percent': psutil.cpu_percent(),
        'memory_percent': psutil.virtual_memory().percent,
        'gpu_util': 0,
        'gpu_memory': 0,
    }
    if torch.cuda.is_available():
        # Query GPUtil once per snapshot, and tolerate an empty list (e.g. when
        # nvidia-smi is unavailable) instead of raising IndexError on [0].
        gpus = GPUtil.getGPUs()
        if gpus:
            metrics['gpu_util'] = gpus[0].load * 100
            metrics['gpu_memory'] = gpus[0].memoryUtil * 100
    return metrics
def _doc_passages(doc):
    """Return the text passages contained in one ``documents`` entry.

    A string entry is a single passage, ``None`` has none, and any other value
    is iterated with its string items kept.
    NOTE(review): RAGBench appears to store a list of passage strings per row
    rather than one string — confirm against the dataset schema.
    """
    if isinstance(doc, str):
        return [doc]
    if doc is None:
        return []
    return [passage for passage in doc if isinstance(passage, str)]


def _find_first_context(query, dataset_choice):
    """Return ``(passage, dataset_name)`` for the first passage matching *query*, or None.

    A passage matches when any whitespace-separated word of *query* occurs in it
    as a case-insensitive substring. Only *dataset_choice* is searched unless it
    is ``"all"``, in which case every entry of ``datasets`` is searched in order.
    Unknown dataset names are skipped.
    """
    # Lower-case the keywords once, not once per passage.
    keywords = [word.lower() for word in query.split()]
    if not keywords:
        return None
    search_names = datasets.keys() if dataset_choice == "all" else [dataset_choice]
    for dataset_name in search_names:
        if dataset_name not in datasets:
            continue
        for doc in datasets[dataset_name]['documents']:
            for passage in _doc_passages(doc):
                lowered = passage.lower()
                if any(keyword in lowered for keyword in keywords):
                    # Only the first hit is ever used, so stop scanning here.
                    return passage, dataset_name
    return None


def _format_metrics(metrics):
    """Render a get_system_metrics() dict plus 'processing_time' as display text."""
    return "\n".join([
        f"Processing Time: {metrics['processing_time']:.2f}s",
        f"CPU Usage: {metrics['cpu_percent']}%",
        f"Memory Usage: {metrics['memory_percent']}%",
        f"GPU Utilization: {metrics['gpu_util']:.1f}%",
        f"GPU Memory: {metrics['gpu_memory']:.1f}%",
    ])


def process_query(query, dataset_choice="all"):
    """Answer *query* with the OpenAI chat API, using a keyword-matched passage as context.

    Args:
        query: The user's question (``None`` is treated as empty).
        dataset_choice: A key of ``datasets`` to search, or ``"all"`` to search
            every loaded dataset.

    Returns:
        ``(answer_text, metrics_text)``. On any failure the first element is the
        exception message and the second a placeholder string; the traceback is
        logged.

    NOTE(review): matching is any-word substring search, so common words
    (e.g. "the") match almost any passage; the E5 embedding model loaded at
    start-up is not used for retrieval here.
    """
    # perf_counter is monotonic and intended for measuring durations.
    start_time = time.perf_counter()
    query = query or ""
    try:
        match = _find_first_context(query, dataset_choice)
        if match:
            passage, source = match
            context_info = f"From {source}: {passage}"
        else:
            context_info = "Searching across datasets..."
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a knowledgeable expert using E5 embeddings for precise information retrieval."},
                {"role": "user", "content": f"Context: {context_info}\nQuestion: {query}"},
            ],
            max_tokens=300,
            temperature=0.7,
        )
        # Get performance metrics
        metrics = get_system_metrics()
        metrics['processing_time'] = time.perf_counter() - start_time
        # The message content can be None (e.g. a refusal or empty completion).
        answer = response.choices[0].message.content or ""
        return answer.strip(), _format_metrics(metrics)
    except Exception as e:
        # UI boundary: show the error in the Response box instead of failing the
        # request, but keep the full traceback in the server logs.
        logger.exception("Failed to process query")
        return str(e), "Metrics collection in progress"
# Build the web UI. Inputs are a free-text question and a dataset selector
# ("all" or one loaded subset). Outputs are the model's answer and a
# plain-text performance report, both produced by process_query.
demo = gr.Interface(
    fn=process_query,
    inputs=[
        gr.Textbox(placeholder="Ask your question here", label="Question"),
        gr.Dropdown(label="Select Dataset", choices=["all", *dataset_names], value="all"),
    ],
    outputs=[gr.Textbox(label="Response"), gr.Textbox(label="Performance Metrics")],
    title="E5-Powered Knowledge Base",
    description="Search across RagBench datasets with performance monitoring",
)
if __name__ == "__main__":
    # Enable request queuing, then serve. queue() returns the app itself,
    # so the two calls can be chained.
    demo.queue().launch(debug=True)