|
import os |
|
import gradio as gr |
|
import pandas as pd |
|
from datetime import datetime |
|
from pydantic import BaseModel, Field |
|
from typing import List, Dict, Any, Optional |
|
import numpy as np |
|
from mistralai import Mistral |
|
from openai import OpenAI |
|
import re |
|
import json |
|
import logging |
|
import time |
|
import concurrent.futures |
|
from concurrent.futures import ThreadPoolExecutor |
|
import threading |
|
import pymongo |
|
from pymongo import MongoClient |
|
from bson.objectid import ObjectId |
|
from dotenv import load_dotenv |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
# Console logging for the whole application: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every class and function in this file.
logger = logging.getLogger(__name__)
|
|
|
class HallucinationJudgment(BaseModel):
    """Structured verdict produced by the judge model for one query."""

    # Every field is required; `...` spells that out explicitly.
    hallucination_detected: bool = Field(
        ...,
        description="Whether a hallucination is detected across the responses",
    )
    confidence_score: float = Field(
        ...,
        description="Confidence score between 0-1 for the hallucination judgment",
    )
    conflicting_facts: List[Dict[str, Any]] = Field(
        ...,
        description="List of conflicting facts found in the responses",
    )
    reasoning: str = Field(
        ...,
        description="Detailed reasoning for the judgment",
    )
    summary: str = Field(
        ...,
        description="A summary of the analysis",
    )
|
|
|
class PAS2:
    """Paraphrase-based Approach for LLM Systems - Using llm-as-judge methods

    The original query and several paraphrases of it are sent to a Mistral
    model; an OpenAI model then judges whether the answers contradict each
    other, which is treated as evidence of hallucination.
    """

    def __init__(self, mistral_api_key=None, openai_api_key=None, progress_callback=None):
        """Initialize the PAS2 with API keys

        Args:
            mistral_api_key: Mistral key; when falsy, the HF_MISTRAL_API_KEY and
                then MISTRAL_API_KEY environment variables are consulted.
            openai_api_key: OpenAI key; when falsy, the HF_OPENAI_API_KEY and
                then OPENAI_API_KEY environment variables are consulted.
            progress_callback: optional callable invoked as
                ``callback(stage_name, **details)`` as the pipeline advances.

        Raises:
            ValueError: if either key cannot be resolved.
        """
        self.mistral_api_key = mistral_api_key or os.environ.get("HF_MISTRAL_API_KEY") or os.environ.get("MISTRAL_API_KEY")
        self.openai_api_key = openai_api_key or os.environ.get("HF_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        self.progress_callback = progress_callback

        if not self.mistral_api_key:
            raise ValueError("Mistral API key is required. Set it via HF_MISTRAL_API_KEY in Hugging Face Spaces secrets or pass it as a parameter.")

        if not self.openai_api_key:
            raise ValueError("OpenAI API key is required. Set it via HF_OPENAI_API_KEY in Hugging Face Spaces secrets or pass it as a parameter.")

        self.mistral_client = Mistral(api_key=self.mistral_api_key)
        self.openai_client = OpenAI(api_key=self.openai_api_key)

        self.mistral_model = "mistral-large-latest"
        self.openai_model = "o3-mini"

        logger.info("PAS2 initialized with Mistral model: %s and OpenAI model: %s",
                    self.mistral_model, self.openai_model)

    @staticmethod
    def _extract_paraphrases(payload) -> List[str]:
        """Return the non-empty paraphrase strings found in a decoded JSON payload.

        Accepts a bare list, or a dict holding the list under "paraphrases",
        "results", or (as a last resort) the first key whose value is a
        non-empty list. Anything else yields an empty list.
        """
        candidates = []
        if isinstance(payload, list):
            candidates = payload
        elif isinstance(payload, dict):
            for key in ("paraphrases", "results"):
                if isinstance(payload.get(key), list):
                    candidates = payload[key]
                    break
            else:
                candidates = next(
                    (value for value in payload.values() if isinstance(value, list) and value),
                    [],
                )
        # Only real, non-blank strings can be sent on as chat queries.
        return [item for item in candidates if isinstance(item, str) and item.strip()]

    def _report_response_progress(self, completed: int, total: int, **extra) -> None:
        """Send a "responses_progress" event to the progress callback, if one is set.

        The counters are sent under both spellings used in this file:
        ProgressTracker.update_stage reads ``completed_responses`` /
        ``total_responses`` while the UI's combined callback reads
        ``completed`` / ``total``.
        """
        if self.progress_callback:
            self.progress_callback(
                "responses_progress",
                completed=completed,
                total=total,
                completed_responses=completed,
                total_responses=total,
                **extra,
            )

    def generate_paraphrases(self, query: str, n_paraphrases: int = 3) -> List[str]:
        """Generate paraphrases of the input query using Mistral API

        Returns:
            A list whose first element is the original query, followed by at
            most ``n_paraphrases`` paraphrases. If the API call or the parsing
            of its answer fails, templated fallback paraphrases are used.
        """
        logger.info("Generating %d paraphrases for query: %s", n_paraphrases, query)
        start_time = time.time()

        messages = [
            {
                "role": "system",
                "content": f"You are an expert at creating semantically equivalent paraphrases. Generate {n_paraphrases} different paraphrases of the given query that preserve the original meaning but vary in wording and structure. Return a JSON array of strings, each containing one paraphrase."
            },
            {
                "role": "user",
                "content": query
            }
        ]

        try:
            logger.info("Sending paraphrase generation request to Mistral API...")
            response = self.mistral_client.chat.complete(
                model=self.mistral_model,
                messages=messages,
                response_format={"type": "json_object"}
            )

            content = response.choices[0].message.content
            logger.debug("Received raw paraphrase response: %s", content)

            paraphrases = self._extract_paraphrases(json.loads(content))
            if not paraphrases:
                logger.warning("Could not extract paraphrases from response: %s", content)
                raise ValueError(f"Could not extract paraphrases from response: {content}")

            # The model may return more than requested; keep only the first n.
            paraphrases = paraphrases[:n_paraphrases]

            all_queries = [query] + paraphrases

            elapsed_time = time.time() - start_time
            logger.info("Generated %d paraphrases in %.2f seconds", len(paraphrases), elapsed_time)
            for i, p in enumerate(paraphrases, 1):
                logger.info("Paraphrase %d: %s", i, p)

            return all_queries

        except Exception as e:
            # Best-effort: any API/parsing failure degrades to canned templates
            # so that the rest of the pipeline can still run.
            logger.error("Error generating paraphrases: %s", str(e), exc_info=True)

            fallback_paraphrases = [
                query,
                f"Could you tell me about {query.strip('?')}?",
                f"I'd like to know: {query}",
                f"Please provide information on {query.strip('?')}."
            ][:n_paraphrases+1]

            logger.info("Using fallback paraphrases due to error")
            for i, p in enumerate(fallback_paraphrases[1:], 1):
                logger.info("Fallback paraphrase %d: %s", i, p)

            return fallback_paraphrases

    def _get_single_response(self, query: str, index: Optional[int] = None) -> str:
        """Get a single response from Mistral API for a query

        Never raises: on any failure a fixed "Error: ..." string is returned
        so that callers always receive one string per query.
        """
        try:
            query_description = f"Query {index}: {query}" if index is not None else f"Query: {query}"
            logger.info("Getting response for %s", query_description)
            start_time = time.time()

            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant. Provide accurate, factual information in response to questions."
                },
                {
                    "role": "user",
                    "content": query
                }
            ]

            response = self.mistral_client.chat.complete(
                model=self.mistral_model,
                messages=messages
            )

            result = response.choices[0].message.content
            elapsed_time = time.time() - start_time

            logger.info("Received response for %s (%.2f seconds)", query_description, elapsed_time)
            preview = result[:100] + "..." if len(result) > 100 else result
            logger.debug("Response content for %s: %s", query_description, preview)

            return result

        except Exception as e:
            logger.error("Error getting response for query '%s': %s", query, e, exc_info=True)
            return "Error: Failed to get response for this query."

    def get_responses(self, queries: List[str]) -> List[str]:
        """Get responses from Mistral API for each query in parallel

        Returns:
            One response string per query, in the same order as ``queries``.
            An empty input yields an empty list.
        """
        if not queries:
            # ThreadPoolExecutor rejects max_workers=0 with ValueError, so an
            # empty batch is answered without creating a pool.
            return []

        total = len(queries)
        logger.info("Getting responses for %d queries in parallel", total)
        start_time = time.time()

        # Pre-sized so results can be stored by index as they complete out of order.
        responses = [""] * total
        completed_count = 0

        with ThreadPoolExecutor(max_workers=min(total, 5)) as executor:
            future_to_index = {
                executor.submit(self._get_single_response, query, i): i
                for i, query in enumerate(queries)
            }

            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    responses[index] = future.result()
                except Exception as e:
                    logger.error("Error processing response for index %d: %s", index, str(e))
                    responses[index] = f"Error: Failed to get response for query {index}."

                # Reported outside the try so a failing callback cannot
                # overwrite a response that was fetched successfully.
                completed_count += 1
                self._report_response_progress(completed_count, total)

        elapsed_time = time.time() - start_time
        logger.info("Received all %d responses in %.2f seconds total", len(responses), elapsed_time)

        return responses

    def detect_hallucination(self, query: str, n_paraphrases: int = 3) -> Dict:
        """
        Detect hallucinations by comparing responses to paraphrased queries using a judge model

        Responses are fetched one query at a time, with a progress event
        emitted before each request.

        Returns:
            Dict containing hallucination judgment and all responses
        """
        logger.info("Starting hallucination detection for query: %s", query)
        start_time = time.time()

        if self.progress_callback:
            self.progress_callback("starting", query=query)

        logger.info("Step 1: Generating paraphrases")
        if self.progress_callback:
            self.progress_callback("generating_paraphrases", query=query)

        all_queries = self.generate_paraphrases(query, n_paraphrases)

        if self.progress_callback:
            self.progress_callback("paraphrases_complete", query=query, count=len(all_queries))

        logger.info("Step 2: Getting responses to all %d queries", len(all_queries))
        if self.progress_callback:
            self.progress_callback("getting_responses", query=query, total=len(all_queries))

        all_responses = []
        for i, q in enumerate(all_queries):
            logger.info("Getting response %d/%d for query: %s", i+1, len(all_queries), q)
            # `i` is the number of responses already completed at this point.
            self._report_response_progress(i, len(all_queries), query=query)

            response = self._get_single_response(q, index=i)
            all_responses.append(response)

        if self.progress_callback:
            self.progress_callback("responses_complete", query=query)

        logger.info("Step 3: Judging for hallucinations")
        if self.progress_callback:
            self.progress_callback("judging", query=query)

        # Element 0 is always the original query; the rest are paraphrases.
        original_query = all_queries[0]
        original_response = all_responses[0]
        paraphrased_queries = all_queries[1:] if len(all_queries) > 1 else []
        paraphrased_responses = all_responses[1:] if len(all_responses) > 1 else []

        judgment = self.judge_hallucination(
            original_query=original_query,
            original_response=original_response,
            paraphrased_queries=paraphrased_queries,
            paraphrased_responses=paraphrased_responses
        )

        results = {
            "original_query": original_query,
            "original_response": original_response,
            "paraphrased_queries": paraphrased_queries,
            "paraphrased_responses": paraphrased_responses,
            "hallucination_detected": judgment.hallucination_detected,
            "confidence_score": judgment.confidence_score,
            "conflicting_facts": judgment.conflicting_facts,
            "reasoning": judgment.reasoning,
            "summary": judgment.summary
        }

        if self.progress_callback:
            self.progress_callback("complete", query=query)

        logger.info("Hallucination detection completed in %.2f seconds", time.time() - start_time)
        return results

    def judge_hallucination(self,
                            original_query: str,
                            original_response: str,
                            paraphrased_queries: List[str],
                            paraphrased_responses: List[str]) -> "HallucinationJudgment":
        """
        Use OpenAI's o3-mini as a judge to detect hallucinations in the responses

        Never raises: if the API call, JSON decoding or model validation
        fails, a "no hallucination, confidence 0.0" judgment describing the
        failure is returned instead.
        """
        logger.info("Judging hallucinations with OpenAI's %s model", self.openai_model)
        start_time = time.time()

        context = f"""
        Original Question: {original_query}

        Original Response:
        {original_response}

        Paraphrased Questions and their Responses:
        """

        # zip() pairs each paraphrase with its answer; extras on either side are ignored.
        for i, (query, response) in enumerate(zip(paraphrased_queries, paraphrased_responses), 1):
            context += f"\nParaphrased Question {i}: {query}\n\nResponse {i}:\n{response}\n"

        system_prompt = """
        You are a judge evaluating whether an AI is hallucinating across different responses to semantically equivalent questions.
        Analyze all responses carefully to identify any factual inconsistencies or contradictions.
        Focus on factual discrepancies, not stylistic differences.
        A hallucination is when the AI states different facts in response to questions that are asking for the same information.

        Your response should be a JSON with the following fields:
        - hallucination_detected: boolean indicating whether hallucinations were found
        - confidence_score: number between 0 and 1 representing your confidence in the judgment
        - conflicting_facts: an array of objects describing any conflicting information found
        - reasoning: detailed explanation for your judgment
        - summary: a concise summary of your analysis
        """

        try:
            logger.info("Sending judgment request to OpenAI API...")
            response = self.openai_client.chat.completions.create(
                model=self.openai_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Evaluate these responses for hallucinations:\n\n{context}"}
                ],
                response_format={"type": "json_object"}
            )

            result_json = json.loads(response.choices[0].message.content)
            logger.debug("Received judgment response: %s", result_json)

            # Missing fields fall back to neutral defaults rather than failing validation.
            judgment = HallucinationJudgment(
                hallucination_detected=result_json.get("hallucination_detected", False),
                confidence_score=result_json.get("confidence_score", 0.0),
                conflicting_facts=result_json.get("conflicting_facts", []),
                reasoning=result_json.get("reasoning", "No reasoning provided."),
                summary=result_json.get("summary", "No summary provided.")
            )

            elapsed_time = time.time() - start_time
            logger.info("Judgment completed in %.2f seconds", elapsed_time)

            return judgment

        except Exception as e:
            logger.error("Error in hallucination judgment: %s", str(e), exc_info=True)

            return HallucinationJudgment(
                hallucination_detected=False,
                confidence_score=0.0,
                conflicting_facts=[],
                reasoning="Failed to obtain judgment from the model.",
                summary="Analysis failed due to API error."
            )
|
|
|
|
|
class HallucinationDetectorApp:
    """Application layer tying the PAS2 detector to MongoDB feedback storage.

    Every database-facing method degrades gracefully: when no MongoDB
    connection is available (``feedback_collection is None``) it logs the
    problem and returns a neutral value instead of raising.
    """

    def __init__(self):
        self.pas2 = None
        logger.info("Initializing HallucinationDetectorApp")
        self._initialize_database()
        self.progress_callback = None

    def _initialize_database(self):
        """Initialize MongoDB connection for persistent feedback storage

        Reads the connection string from the MONGODB_URI environment
        variable. On success ``mongo_client``, ``db`` and
        ``feedback_collection`` are set; when the variable is missing or the
        connection fails, all three are set to None.
        """
        self.mongo_client = None
        self.db = None
        self.feedback_collection = None

        mongo_uri = os.environ.get("MONGODB_URI")
        if not mongo_uri:
            # No URI configured: skip the connection attempt entirely instead
            # of dialling a placeholder address with dummy credentials.
            logger.warning("MONGODB_URI not found in environment variables. Please set it in HuggingFace Spaces secrets.")
            logger.warning("Proceeding without database connection. Data will not be saved persistently.")
            return

        client = None
        try:
            client = MongoClient(mongo_uri)

            # Verify the server is reachable before touching any collection.
            client.admin.command('ping')

            db = client["hallucination_detector"]
            feedback_collection = db["feedback"]
            feedback_collection.create_index("timestamp")
        except Exception as e:
            logger.error(f"Error initializing MongoDB: {str(e)}", exc_info=True)
            logger.warning("Proceeding without database connection. Data will not be saved persistently.")
            if client is not None:
                # Release the client created for the failed attempt.
                client.close()
            return

        self.mongo_client = client
        self.db = db
        self.feedback_collection = feedback_collection
        logger.info("MongoDB connection successful")

    def set_progress_callback(self, callback):
        """Set the progress callback function"""
        self.progress_callback = callback

    def initialize_api(self, mistral_api_key, openai_api_key):
        """Initialize the PAS2 with API keys

        Returns:
            A human-readable status message; it contains "successfully" only
            when the detector was created.
        """
        try:
            logger.info("Initializing PAS2 with API keys")
            self.pas2 = PAS2(
                mistral_api_key=mistral_api_key,
                openai_api_key=openai_api_key,
                progress_callback=self.progress_callback
            )
            logger.info("API initialization successful")
            return "API keys set successfully! You can now use the application."
        except Exception as e:
            logger.error("Error initializing API: %s", str(e), exc_info=True)
            return f"Error initializing API: {str(e)}"

    def process_query(self, query: str):
        """Process the query using PAS2

        Returns:
            The result dict from ``PAS2.detect_hallucination``, or a dict with
            a single "error" key when the detector is not initialized, the
            query is empty, or processing fails.
        """
        if not self.pas2:
            logger.error("PAS2 not initialized")
            return {
                "error": "Please set API keys first before processing queries."
            }

        # `not query` also covers None, which has no .strip().
        if not query or not query.strip():
            logger.warning("Empty query provided")
            return {
                "error": "Please enter a query."
            }

        try:
            # The callback may have been registered after PAS2 was created.
            if self.progress_callback and self.pas2.progress_callback != self.progress_callback:
                self.pas2.progress_callback = self.progress_callback

            logger.info("Processing query with PAS2: %s", query)
            results = self.pas2.detect_hallucination(query)
            logger.info("Query processing completed successfully")
            return results
        except Exception as e:
            logger.error("Error processing query: %s", str(e), exc_info=True)
            return {
                "error": f"Error processing query: {str(e)}"
            }

    def save_feedback(self, results, feedback):
        """Save results and user feedback to MongoDB

        Returns:
            A status message describing whether the document was stored.
        """
        try:
            logger.info("Saving user feedback: %s", feedback)

            if self.feedback_collection is None:
                logger.error("MongoDB connection not available. Cannot save feedback.")
                return "Database connection not available. Feedback not saved."

            if not isinstance(results, dict):
                # Feedback submitted before any analysis produced a result dict.
                logger.error("No analysis results available. Cannot save feedback.")
                return "Error saving feedback: no analysis results available."

            document = {
                "timestamp": datetime.now(),
                "original_query": results.get('original_query', ''),
                "original_response": results.get('original_response', ''),
                "paraphrased_queries": results.get('paraphrased_queries', []),
                "paraphrased_responses": results.get('paraphrased_responses', []),
                "hallucination_detected": results.get('hallucination_detected', False),
                "confidence_score": results.get('confidence_score', 0.0),
                "conflicting_facts": results.get('conflicting_facts', []),
                "reasoning": results.get('reasoning', ''),
                "summary": results.get('summary', ''),
                "user_feedback": feedback
            }

            self.feedback_collection.insert_one(document)

            logger.info("Feedback saved successfully to MongoDB")
            return "Feedback saved successfully!"
        except Exception as e:
            logger.error("Error saving feedback: %s", str(e), exc_info=True)
            return f"Error saving feedback: {str(e)}"

    def get_feedback_stats(self):
        """Get statistics about collected feedback from MongoDB

        Returns:
            Dict with "total_feedback", "correct_predictions" (documents whose
            ``user_feedback`` starts with "Yes") and "accuracy", or None when
            the database is unavailable or the query fails.
        """
        try:
            if self.feedback_collection is None:
                logger.error("MongoDB connection not available. Cannot get feedback stats.")
                return None

            total_count = self.feedback_collection.count_documents({})

            # Non-string feedback values (e.g. None) are simply not counted
            # instead of aborting the whole computation.
            correct_predictions = sum(
                1
                for doc in self.feedback_collection.find({}, {"user_feedback": 1})
                if isinstance(doc.get("user_feedback"), str)
                and doc["user_feedback"].startswith("Yes")
            )

            # max(..., 1) avoids division by zero on an empty collection.
            accuracy = correct_predictions / max(total_count, 1)

            return {
                "total_feedback": total_count,
                "correct_predictions": correct_predictions,
                "accuracy": accuracy
            }
        except Exception as e:
            logger.error("Error getting feedback stats: %s", str(e), exc_info=True)
            return None

    def export_data_to_csv(self, filepath=None):
        """Export all feedback data to a CSV file for analysis

        Args:
            filepath: destination path; when falsy, a timestamped file next to
                this module is used.

        Returns:
            The path written on success, otherwise an error message string.
        """
        try:
            if self.feedback_collection is None:
                logger.error("MongoDB connection not available. Cannot export data.")
                return "Database connection not available. Cannot export data."

            records = list(self.feedback_collection.find({}))

            for record in records:
                # ObjectId is not CSV-friendly; store its string form.
                record['_id'] = str(record['_id'])

                # Only datetime values have strftime; anything else is kept as stored.
                if isinstance(record.get('timestamp'), datetime):
                    record['timestamp'] = record['timestamp'].strftime("%Y-%m-%d %H:%M:%S")

                # Nested lists are serialized so each fits in a single CSV cell.
                for field in ('paraphrased_queries', 'paraphrased_responses', 'conflicting_facts'):
                    if field in record:
                        record[field] = json.dumps(record[field])

            df = pd.DataFrame(records)

            if not filepath:
                filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                        f"hallucination_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv")

            df.to_csv(filepath, index=False)
            logger.info(f"Data successfully exported to {filepath}")

            return filepath
        except Exception as e:
            logger.error(f"Error exporting data: {str(e)}", exc_info=True)
            return f"Error exporting data: {str(e)}"

    def get_recent_queries(self, limit=10):
        """Get most recent queries for display in the UI

        Returns:
            Up to ``limit`` dicts (newest first) with keys "id", "query",
            "hallucination_detected" and "timestamp"; an empty list when the
            database is unavailable or the query fails.
        """
        try:
            if self.feedback_collection is None:
                logger.error("MongoDB connection not available. Cannot get recent queries.")
                return []

            cursor = self.feedback_collection.find(
                {},
                {"original_query": 1, "hallucination_detected": 1, "timestamp": 1}
            ).sort("timestamp", pymongo.DESCENDING).limit(limit)

            recent_queries = []
            for doc in cursor:
                # .get() keeps one incomplete document from discarding the whole list.
                timestamp = doc.get("timestamp")
                if isinstance(timestamp, datetime):
                    timestamp = timestamp.strftime("%Y-%m-%d %H:%M:%S")
                recent_queries.append({
                    "id": str(doc["_id"]),
                    "query": doc.get("original_query", ""),
                    "hallucination_detected": doc.get("hallucination_detected", False),
                    "timestamp": timestamp
                })

            return recent_queries
        except Exception as e:
            logger.error(f"Error getting recent queries: {str(e)}", exc_info=True)
            return []

    def get_query_details(self, query_id):
        """Get full details for a specific query by ID

        Returns:
            The stored document with "_id" (and a datetime "timestamp")
            converted to strings, or None when the database is unavailable,
            the ID is invalid or no document matches.
        """
        try:
            if self.feedback_collection is None:
                logger.error("MongoDB connection not available. Cannot get query details.")
                return None

            # An invalid id raises here and is handled by the except below.
            obj_id = ObjectId(query_id)

            doc = self.feedback_collection.find_one({"_id": obj_id})

            if doc is None:
                logger.warning(f"No query found with ID {query_id}")
                return None

            doc["_id"] = str(doc["_id"])

            if "timestamp" in doc and isinstance(doc["timestamp"], datetime):
                doc["timestamp"] = doc["timestamp"].strftime("%Y-%m-%d %H:%M:%S")

            return doc
        except Exception as e:
            logger.error(f"Error getting query details: {str(e)}", exc_info=True)
            return None
|
|
|
|
|
|
|
class ProgressTracker:
    """Tracks progress of hallucination detection for UI updates

    Holds the current stage plus its display data and renders them as an
    HTML snippet. A registered callback receives the fresh HTML after every
    stage update and on every pulse tick.
    """

    # Display template for each stage: status text, progress percentage and color.
    STAGES = {
        "idle": {"status": "Ready", "progress": 0, "color": "#757575"},
        "starting": {"status": "Starting process...", "progress": 5, "color": "#2196F3"},
        "generating_paraphrases": {"status": "Generating paraphrases...", "progress": 15, "color": "#2196F3"},
        "paraphrases_complete": {"status": "Paraphrases generated", "progress": 30, "color": "#2196F3"},
        "getting_responses": {"status": "Getting responses (0/0)...", "progress": 35, "color": "#2196F3"},
        "responses_progress": {"status": "Getting responses ({completed}/{total})...", "progress": 40, "color": "#2196F3"},
        "responses_complete": {"status": "All responses received", "progress": 65, "color": "#2196F3"},
        "judging": {"status": "Analyzing responses for hallucinations...", "progress": 70, "color": "#2196F3"},
        "complete": {"status": "Analysis complete!", "progress": 100, "color": "#4CAF50"},
        "error": {"status": "Error: {error_message}", "progress": 100, "color": "#F44336"}
    }

    def __init__(self):
        self.stage = "idle"
        self.stage_data = self.STAGES[self.stage].copy()
        self.query = ""
        self.completed_responses = 0
        self.total_responses = 0
        self.error_message = ""
        # Guards stage/stage_data against the pulsing thread.
        self._lock = threading.Lock()
        self._status_callback = None
        self._stop_event = threading.Event()
        self._update_thread = None

    def register_callback(self, callback_fn):
        """Register callback function to update UI"""
        self._status_callback = callback_fn

    def update_stage(self, stage, **kwargs):
        """Update the current stage and trigger callback

        Args:
            stage: a key of ``STAGES``; unknown names leave the stage as is.
            **kwargs: optional details. Recognized keys are ``query``,
                ``completed_responses`` (alias ``completed``),
                ``total_responses`` (alias ``total``) and ``error_message``;
                any other key is ignored.
        """
        with self._lock:
            if stage in self.STAGES:
                self.stage = stage
                self.stage_data = self.STAGES[stage].copy()

            for key, value in kwargs.items():
                if key == 'query':
                    self.query = value
                # PAS2.detect_hallucination reports counters as completed/total,
                # PAS2.get_responses as completed_responses/total_responses;
                # both spellings are accepted.
                elif key in ('completed_responses', 'completed'):
                    self.completed_responses = value
                elif key in ('total_responses', 'total'):
                    self.total_responses = value
                elif key == 'error_message':
                    self.error_message = value

            # Fill in the placeholders of the two templated status strings.
            if stage == 'responses_progress':
                self.stage_data['status'] = self.stage_data['status'].format(
                    completed=self.completed_responses,
                    total=self.total_responses
                )
            elif stage == 'error':
                self.stage_data['status'] = self.stage_data['status'].format(
                    error_message=self.error_message
                )

            if self._status_callback:
                self._status_callback(self.get_html_status())

    def get_html_status(self):
        """Get HTML representation of current status

        The query and the status text may contain user input or exception
        messages, so both are HTML-escaped before being embedded.
        """
        # Imported locally: the module does not import `html` at file level.
        from html import escape

        progress_width = f"{self.stage_data['progress']}%"
        status_text = escape(str(self.stage_data['status']))
        color = self.stage_data['color']

        query_info = f'<div class="query-display">{escape(str(self.query))}</div>' if self.query else ''

        # The idle state shows only the (empty) bar, without a status line.
        status_display = f'<div class="progress-status" style="color: {color};">{status_text}</div>' if self.stage != "idle" else ''

        markup = f"""
        <div class="progress-container">
            {query_info}
            {status_display}
            <div class="progress-bar-container">
                <div class="progress-bar" style="width: {progress_width}; background-color: {color};"></div>
            </div>
        </div>
        """
        return markup

    def start_pulsing(self):
        """Start a pulsing animation for the progress bar during long operations"""
        if self._update_thread and self._update_thread.is_alive():
            return

        self._stop_event.clear()
        self._update_thread = threading.Thread(target=self._pulse_progress)
        self._update_thread.daemon = True
        self._update_thread.start()

    def stop_pulsing(self):
        """Stop the pulsing animation"""
        self._stop_event.set()
        if self._update_thread:
            self._update_thread.join(0.5)

    def _pulse_progress(self):
        """Animate the progress bar to show activity"""
        pulse_stages = ["⋯", "⋯⋯", "⋯⋯⋯", "⋯⋯", "⋯"]
        i = 0
        while not self._stop_event.is_set():
            with self._lock:
                if self.stage not in ["idle", "complete", "error"]:
                    # Drop any previous pulse suffix before appending the next one.
                    status_base = self.stage_data['status'].split("...")[0] if "..." in self.stage_data['status'] else self.stage_data['status']
                    self.stage_data['status'] = f"{status_base}... {pulse_stages[i]}"

                    if self._status_callback:
                        self._status_callback(self.get_html_status())

            i = (i + 1) % len(pulse_stages)
            # Event.wait doubles as an interruptible sleep, so stop_pulsing()
            # ends the loop without waiting out the full interval.
            self._stop_event.wait(0.3)
|
|
|
|
|
def create_interface(): |
|
"""Create Gradio interface""" |
|
detector = HallucinationDetectorApp() |
|
|
|
|
|
progress_tracker = ProgressTracker() |
|
|
|
|
|
try: |
|
detector.initialize_api( |
|
mistral_api_key=os.environ.get("HF_MISTRAL_API_KEY"), |
|
openai_api_key=os.environ.get("HF_OPENAI_API_KEY") |
|
) |
|
except Exception as e: |
|
print(f"Warning: Failed to initialize APIs from environment variables: {e}") |
|
print("Please make sure HF_MISTRAL_API_KEY and HF_OPENAI_API_KEY are set in your environment") |
|
|
|
|
|
css = """ |
|
.container { |
|
max-width: 1000px; |
|
margin: 0 auto; |
|
} |
|
.title { |
|
text-align: center; |
|
margin-bottom: 0.5em; |
|
color: #1a237e; |
|
font-weight: 600; |
|
} |
|
.subtitle { |
|
text-align: center; |
|
margin-bottom: 1.5em; |
|
color: #455a64; |
|
font-size: 1.2em; |
|
} |
|
.section-title { |
|
margin-top: 1em; |
|
margin-bottom: 0.5em; |
|
font-weight: bold; |
|
color: #283593; |
|
} |
|
.info-box { |
|
padding: 1.2em; |
|
border-radius: 8px; |
|
background-color: #f5f5f5; |
|
margin-bottom: 1em; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.05); |
|
} |
|
.hallucination-positive { |
|
padding: 1.2em; |
|
border-radius: 8px; |
|
background-color: #ffebee; |
|
border-left: 5px solid #f44336; |
|
margin-bottom: 1em; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.05); |
|
} |
|
.hallucination-negative { |
|
padding: 1.2em; |
|
border-radius: 8px; |
|
background-color: #e8f5e9; |
|
border-left: 5px solid #4caf50; |
|
margin-bottom: 1em; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.05); |
|
} |
|
.response-box { |
|
padding: 1.2em; |
|
border-radius: 8px; |
|
background-color: #f5f5f5; |
|
margin-bottom: 0.8em; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.05); |
|
} |
|
.example-queries { |
|
display: flex; |
|
flex-wrap: wrap; |
|
gap: 8px; |
|
margin-bottom: 15px; |
|
} |
|
.example-query { |
|
background-color: #e3f2fd; |
|
padding: 8px 15px; |
|
border-radius: 18px; |
|
font-size: 0.9em; |
|
cursor: pointer; |
|
transition: all 0.2s; |
|
border: 1px solid #bbdefb; |
|
} |
|
.example-query:hover { |
|
background-color: #bbdefb; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.1); |
|
} |
|
.stats-section { |
|
display: flex; |
|
justify-content: space-between; |
|
background-color: #e8eaf6; |
|
padding: 15px; |
|
border-radius: 8px; |
|
margin-bottom: 20px; |
|
} |
|
.stat-item { |
|
text-align: center; |
|
padding: 10px; |
|
} |
|
.stat-value { |
|
font-size: 1.5em; |
|
font-weight: bold; |
|
color: #303f9f; |
|
} |
|
.stat-label { |
|
font-size: 0.9em; |
|
color: #5c6bc0; |
|
} |
|
.feedback-section { |
|
border-top: 1px solid #e0e0e0; |
|
padding-top: 15px; |
|
margin-top: 20px; |
|
} |
|
footer { |
|
text-align: center; |
|
padding: 20px; |
|
margin-top: 30px; |
|
color: #9e9e9e; |
|
font-size: 0.9em; |
|
} |
|
.processing-status { |
|
padding: 12px; |
|
background-color: #fff3e0; |
|
border-left: 4px solid #ff9800; |
|
margin-bottom: 15px; |
|
font-weight: 500; |
|
color: #e65100; |
|
} |
|
.debug-panel { |
|
background-color: #f5f5f5; |
|
border: 1px solid #e0e0e0; |
|
border-radius: 4px; |
|
padding: 10px; |
|
margin-top: 15px; |
|
font-family: monospace; |
|
font-size: 0.9em; |
|
white-space: pre-wrap; |
|
max-height: 200px; |
|
overflow-y: auto; |
|
} |
|
.progress-container { |
|
padding: 15px; |
|
background-color: #fff; |
|
border-radius: 8px; |
|
box-shadow: 0 2px 5px rgba(0,0,0,0.05); |
|
margin-bottom: 15px; |
|
} |
|
.progress-status { |
|
font-weight: 500; |
|
margin-bottom: 8px; |
|
padding: 4px 0; |
|
font-size: 0.95em; |
|
} |
|
.progress-bar-container { |
|
background-color: #e0e0e0; |
|
height: 10px; |
|
border-radius: 5px; |
|
overflow: hidden; |
|
margin-bottom: 10px; |
|
box-shadow: inset 0 1px 3px rgba(0,0,0,0.1); |
|
} |
|
.progress-bar { |
|
height: 100%; |
|
transition: width 0.5s ease; |
|
background-image: linear-gradient(to right, #2196F3, #3f51b5); |
|
} |
|
.query-display { |
|
font-style: italic; |
|
color: #666; |
|
margin-bottom: 10px; |
|
background-color: #f5f5f5; |
|
padding: 8px; |
|
border-radius: 4px; |
|
border-left: 3px solid #2196F3; |
|
} |
|
""" |
|
|
|
|
|
example_queries = [ |
|
"Who was the first person to land on the moon?", |
|
"What is the capital of France?", |
|
"How many planets are in our solar system?", |
|
"Who wrote the novel 1984?", |
|
"What is the speed of light?", |
|
"What was the first computer?" |
|
] |
|
|
|
|
|
def update_progress_display(html): |
|
"""Update the progress display with the provided HTML""" |
|
return gr.update(visible=True, value=html) |
|
|
|
|
|
progress_tracker.register_callback(update_progress_display) |
|
|
|
|
|
detector.set_progress_callback(progress_tracker.update_stage) |
|
|
|
|
|
def set_example_query(example): |
|
return example |
|
|
|
|
|
def start_processing(query): |
|
logger.info("Processing query: %s", query) |
|
|
|
progress_tracker.stop_pulsing() |
|
|
|
|
|
|
|
progress_tracker.stage = "starting" |
|
progress_tracker.query = query |
|
|
|
|
|
if progress_tracker._status_callback: |
|
progress_tracker._status_callback(progress_tracker.get_html_status()) |
|
|
|
return [ |
|
gr.update(visible=True), |
|
gr.update(visible=False), |
|
gr.update(visible=False), |
|
None |
|
] |
|
|
|
|
|
def process_query_and_display_results(query, progress=gr.Progress()):
    """Run hallucination detection for ``query`` and return the UI updates.

    The function initializes the API clients on first use, generates
    paraphrases, collects one response per query, asks the judge for a
    verdict, and renders that verdict as an HTML report. Each stage is
    reported to ``progress_tracker`` and to the Gradio ``progress`` object.

    All text that originates from the user or from a model (query, responses,
    summary, reasoning, conflicting facts) is HTML-escaped before it is placed
    in the report, so markup inside that text is displayed rather than
    interpreted by the browser.

    Args:
        query: The question typed by the user. ``None``, empty and
            whitespace-only values are rejected with an error stage.
        progress: Gradio progress reporter; the ``gr.Progress()`` default is
            part of the handler signature and is kept unchanged.

    Returns:
        A four-item list matching the handler's outputs: an update for the
        progress display, an update for the results HTML, an update for the
        feedback accordion, and the raw results dict (``None`` on failure).
    """
    # Local import so the block does not depend on a new file-level import.
    from html import escape as escape_html

    def fail(message):
        """Stop the pulse, publish ``message`` as an error stage, return the failure UI state."""
        progress_tracker.stop_pulsing()
        progress_tracker.update_stage("error", error_message=message)
        return [
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            None,
        ]

    def to_html(text):
        """Escape ``text`` for HTML, then apply the original backslash/newline rewriting."""
        # Escaping first guarantees the only markup in the result is the <br> added here.
        return escape_html(str(text)).replace('\\', '\\\\').replace('\n', '<br>')

    # ``not query`` also covers None, which has no .strip().
    if not query or not query.strip():
        logger.warning("Empty query submitted")
        return fail("Please enter a query.")

    # Lazily initialize the API clients the first time a query is processed.
    if not detector.pas2:
        try:
            logger.info("Initializing APIs from environment variables")
            progress(0.05, desc="Initializing API...")
            init_message = detector.initialize_api(
                mistral_api_key=os.environ.get("HF_MISTRAL_API_KEY"),
                openai_api_key=os.environ.get("HF_OPENAI_API_KEY")
            )
            if "successfully" not in init_message:
                logger.error("Failed to initialize APIs: %s", init_message)
                return fail("API keys not found in environment variables.")
        except Exception as e:
            logger.error("Error initializing API: %s", str(e), exc_info=True)
            return fail(f"Error initializing API: {str(e)}")

    try:
        logger.info("Starting hallucination detection process")
        start_time = time.time()

        # Progress-bar fraction for each stage; built once instead of per callback.
        stage_to_progress = {
            "starting": 0.05,
            "generating_paraphrases": 0.15,
            "paraphrases_complete": 0.3,
            "getting_responses": 0.35,
            "responses_progress": lambda kwargs: 0.35 + (0.3 * (kwargs.get("completed", 0) / max(kwargs.get("total", 1), 1))),
            "responses_complete": 0.65,
            "judging": 0.7,
            "complete": 1.0,
            "error": 1.0
        }
        # Stages after which a short pause keeps the status readable.
        paused_stages = ("starting", "generating_paraphrases", "paraphrases_complete",
                         "getting_responses", "responses_complete", "judging", "complete")

        def combined_progress_callback(stage, **kwargs):
            """Forward a stage change to the tracker and to the Gradio progress bar."""
            if stage == "idle":
                return

            progress_tracker.update_stage(stage, **kwargs)

            if stage in stage_to_progress:
                prog_value = stage_to_progress[stage]
                if callable(prog_value):
                    prog_value = prog_value(kwargs)

                desc = progress_tracker.STAGES[stage]["status"]
                if "{" in desc and "}" in desc:
                    desc = desc.format(**kwargs)

                progress(prog_value, desc=desc)

                if stage in paused_stages:
                    time.sleep(0.2)

        detector.set_progress_callback(combined_progress_callback)

        def run_detection_with_visible_progress():
            """Run paraphrasing, response collection and judging, reporting each stage."""
            combined_progress_callback("starting", query=query)
            time.sleep(0.3)

            combined_progress_callback("generating_paraphrases", query=query)
            all_queries = detector.pas2.generate_paraphrases(query)
            combined_progress_callback("paraphrases_complete", query=query, count=len(all_queries))

            combined_progress_callback("getting_responses", query=query, total=len(all_queries))
            all_responses = []
            for i, q in enumerate(all_queries):
                combined_progress_callback("responses_progress", query=query, completed=i, total=len(all_queries))
                all_responses.append(detector.pas2._get_single_response(q, index=i))
            combined_progress_callback("responses_complete", query=query)

            combined_progress_callback("judging", query=query)

            # The first entry is treated as the original; the rest are paraphrases.
            original_query = all_queries[0]
            original_response = all_responses[0]
            paraphrased_queries = all_queries[1:]
            paraphrased_responses = all_responses[1:]

            judgment = detector.pas2.judge_hallucination(
                original_query=original_query,
                original_response=original_response,
                paraphrased_queries=paraphrased_queries,
                paraphrased_responses=paraphrased_responses
            )

            results = {
                "original_query": original_query,
                "original_response": original_response,
                "paraphrased_queries": paraphrased_queries,
                "paraphrased_responses": paraphrased_responses,
                "hallucination_detected": judgment.hallucination_detected,
                "confidence_score": judgment.confidence_score,
                "conflicting_facts": judgment.conflicting_facts,
                "reasoning": judgment.reasoning,
                "summary": judgment.summary
            }

            combined_progress_callback("complete", query=query)
            time.sleep(0.3)

            return results

        results = run_detection_with_visible_progress()

        elapsed_time = time.time() - start_time
        logger.info("Hallucination detection completed in %.2f seconds", elapsed_time)

        if "error" in results:
            logger.error("Error in results: %s", results["error"])
            return fail(results["error"])

        original_query = results["original_query"]
        original_response = results["original_response"]
        paraphrased_queries = results["paraphrased_queries"]
        paraphrased_responses = results["paraphrased_responses"]
        hallucination_detected = results["hallucination_detected"]
        confidence = results["confidence_score"]
        reasoning = results["reasoning"]
        summary = results["summary"]

        # One numbered line per conflicting fact; dict facts become "key: value" pairs.
        fact_lines = []
        for i, fact in enumerate(results["conflicting_facts"] or [], 1):
            if isinstance(fact, dict):
                detail = ", ".join(f"{key}: {value}" for key, value in fact.items())
            else:
                detail = str(fact)
            fact_lines.append(f"{i}. {detail}\n")
        conflicting_facts_text = "".join(fact_lines)

        # Everything below is untrusted text (user input or model output): escape it.
        original_query_safe = escape_html(str(original_query))
        original_response_safe = to_html(original_response)
        summary_safe = escape_html(str(summary))
        reasoning_safe = to_html(reasoning)
        conflicting_facts_text_safe = to_html(conflicting_facts_text) if conflicting_facts_text else "None identified"

        html_parts = [f"""
            <div class="container">
                <h2 class="title">Hallucination Detection Results</h2>

                <div class="stats-section">
                    <div class="stat-item">
                        <div class="stat-value">{'Yes' if hallucination_detected else 'No'}</div>
                        <div class="stat-label">Hallucination Detected</div>
                    </div>
                    <div class="stat-item">
                        <div class="stat-value">{confidence:.2f}</div>
                        <div class="stat-label">Confidence Score</div>
                    </div>
                    <div class="stat-item">
                        <div class="stat-value">{len(paraphrased_queries)}</div>
                        <div class="stat-label">Paraphrases Analyzed</div>
                    </div>
                    <div class="stat-item">
                        <div class="stat-value">{elapsed_time:.1f}s</div>
                        <div class="stat-label">Processing Time</div>
                    </div>
                </div>

                <div class="{'hallucination-positive' if hallucination_detected else 'hallucination-negative'}">
                    <h3>Analysis Summary</h3>
                    <p>{summary_safe}</p>
                </div>

                <div class="section-title">Original Query</div>
                <div class="response-box">
                    {original_query_safe}
                </div>

                <div class="section-title">Original Response</div>
                <div class="response-box">
                    {original_response_safe}
                </div>

                <div class="section-title">Paraphrased Queries and Responses</div>
        """]

        for i, (q, r) in enumerate(zip(paraphrased_queries, paraphrased_responses), 1):
            html_parts.append(f"""
                <div class="section-title">Paraphrased Query {i}</div>
                <div class="response-box">
                    {escape_html(str(q))}
                </div>

                <div class="section-title">Response {i}</div>
                <div class="response-box">
                    {to_html(r)}
                </div>
            """)

        html_parts.append(f"""
                <div class="section-title">Detailed Analysis</div>
                <div class="info-box">
                    <p><strong>Reasoning:</strong></p>
                    <p>{reasoning_safe}</p>

                    <p><strong>Conflicting Facts:</strong></p>
                    <p>{conflicting_facts_text_safe}</p>
                </div>
            </div>
        """)
        html_output = "".join(html_parts)

        logger.info("Updating UI with results")
        progress_tracker.stop_pulsing()

        return [
            gr.update(visible=False),
            gr.update(visible=True, value=html_output),
            gr.update(visible=True),
            results
        ]

    except Exception as e:
        # Boundary handler: any failure in the pipeline is surfaced as an error stage.
        logger.error("Error processing query: %s", str(e), exc_info=True)
        return fail(f"Error processing query: {str(e)}")
|
|
|
|
|
def combine_feedback(fb_input, fb_text, results):
    """Merge the selected rating with the optional comment and store it with ``results``.

    Returns the value of ``detector.save_feedback``, or a fixed notice when
    there are no results to attach the feedback to.
    """
    if not results:
        return "No results to attach feedback to."

    # The comment is appended after the rating only when one was entered.
    feedback_note = fb_input if not fb_text else f"{fb_input}: {fb_text}"
    return detector.save_feedback(results, feedback_note)
|
|
|
|
|
with gr.Blocks(css=css, theme=gr.themes.Soft()) as interface: |
|
gr.HTML( |
|
""" |
|
<div style="text-align: center; margin-bottom: 1.5rem"> |
|
<h1 style="font-size: 2.2em; font-weight: 600; color: #1a237e; margin-bottom: 0.2em;">PAS2 - Hallucination Detector</h1> |
|
<h3 style="font-size: 1.3em; color: #455a64; margin-bottom: 0.8em;">Advanced AI Response Verification Using Model-as-Judge</h3> |
|
<p style="font-size: 1.1em; color: #546e7a; max-width: 800px; margin: 0 auto;"> |
|
This tool detects hallucinations in AI responses by comparing answers to semantically equivalent questions and using a specialized judge model. |
|
</p> |
|
</div> |
|
""" |
|
) |
|
|
|
with gr.Accordion("About this Tool", open=False): |
|
gr.Markdown( |
|
""" |
|
### How It Works |
|
|
|
This tool implements the Paraphrase-based Approach for Scrutinizing Systems (PAS2) with a model-as-judge enhancement: |
|
|
|
1. **Paraphrase Generation**: Your question is paraphrased multiple ways while preserving its core meaning |
|
2. **Multiple Responses**: All questions (original + paraphrases) are sent to Mistral Large model |
|
3. **Expert Judgment**: OpenAI's o3-mini analyzes all responses to detect factual inconsistencies |
|
|
|
### Why This Approach? |
|
|
|
When an AI hallucinates, it often provides different answers to the same question when phrased differently. |
|
By using a separate judge model, we can identify these inconsistencies more effectively than with |
|
metric-based approaches. |
|
|
|
### Understanding the Results |
|
|
|
- **Confidence Score**: Indicates the judge's confidence in the hallucination detection |
|
- **Conflicting Facts**: Specific inconsistencies found across responses |
|
- **Reasoning**: The judge's detailed analysis explaining its decision |
|
|
|
### Privacy Notice |
|
|
|
Your queries and the system's responses are saved to help improve hallucination detection. |
|
No personally identifiable information is collected. |
|
""" |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
|
|
gr.Markdown("### Enter Your Question") |
|
with gr.Row(): |
|
query_input = gr.Textbox( |
|
label="", |
|
placeholder="Ask a factual question (e.g., Who was the first person to land on the moon?)", |
|
lines=3 |
|
) |
|
|
|
|
|
gr.Markdown("### Or Try an Example") |
|
example_row = gr.Row() |
|
with example_row: |
|
for example in example_queries: |
|
example_btn = gr.Button( |
|
example, |
|
elem_classes=["example-query"], |
|
scale=0 |
|
) |
|
example_btn.click( |
|
fn=set_example_query, |
|
inputs=[gr.Textbox(value=example, visible=False)], |
|
outputs=[query_input] |
|
) |
|
|
|
with gr.Row(): |
|
submit_button = gr.Button("Detect Hallucinations", variant="primary", scale=1) |
|
|
|
|
|
error_message = gr.HTML( |
|
label="Status", |
|
visible=False |
|
) |
|
|
|
|
|
progress_display = gr.HTML( |
|
value=progress_tracker.get_html_status(), |
|
visible=True |
|
) |
|
|
|
|
|
results_accordion = gr.HTML(visible=False) |
|
|
|
|
|
feedback_stats = gr.HTML(visible=True) |
|
|
|
|
|
def update_stats():
    """Render the feedback statistics as an HTML snippet.

    Reads ``total_feedback``, ``correct_predictions`` and ``accuracy`` from
    ``detector.get_feedback_stats()``. Returns an empty string when no
    statistics are available.
    """
    stats = detector.get_feedback_stats()
    if not stats:
        return ""

    total = stats['total_feedback']
    correct = stats['correct_predictions']
    # Accuracy is shown as a percentage with one decimal place.
    accuracy_pct = f"{stats['accuracy'] * 100:.1f}%"

    return f"""
                    <div class="stats-section" style="background-color: #e8f5e9; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); margin-top: 5px;">
                        <div class="stat-item">
                            <div class="stat-value" style="font-size: 2em; color: #2e7d32;">{total}</div>
                            <div class="stat-label" style="font-weight: bold;">Total Responses</div>
                        </div>
                        <div class="stat-item">
                            <div class="stat-value" style="font-size: 2em; color: #2e7d32;">{accuracy_pct}</div>
                            <div class="stat-label" style="font-weight: bold;">Correct Predictions</div>
                        </div>
                    </div>
                    <div style="text-align: center; margin-top: 10px; font-style: italic; color: #666;">
                        Based on user feedback: {correct} correct out of {total} total predictions
                    </div>
                    """
|
|
|
|
|
with gr.Row(elem_id="stats-container"): |
|
with gr.Column(): |
|
gr.Markdown("### 📊 Live Prediction Accuracy") |
|
gr.Markdown("_Auto-refreshes every 5 seconds from MongoDB based on user feedback_") |
|
live_stats = gr.HTML(update_stats()) |
|
|
|
|
|
gr.HTML(""" |
|
<style> |
|
@keyframes pulse { |
|
0% { opacity: 0.6; } |
|
50% { opacity: 1; } |
|
100% { opacity: 0.6; } |
|
} |
|
.refreshing::after { |
|
content: "⟳"; |
|
display: inline-block; |
|
margin-left: 8px; |
|
animation: pulse 1.5s infinite ease-in-out; |
|
color: #2e7d32; |
|
} |
|
#stats-container { |
|
border: 1px solid #e0e0e0; |
|
border-radius: 10px; |
|
padding: 15px; |
|
margin: 10px 0; |
|
background-color: #2762d7; |
|
} |
|
</style> |
|
<div class="refreshing" style="text-align: right; font-size: 0.8em; color: #666;">Auto-refreshing</div> |
|
""") |
|
|
|
|
|
refresh_btn = gr.Button("Refresh Stats", visible=False) |
|
refresh_btn.click( |
|
fn=update_stats, |
|
outputs=[live_stats] |
|
) |
|
|
|
|
|
gr.HTML(""" |
|
<script> |
|
// Auto-refresh stats every 5 seconds |
|
function setupAutoRefresh() { |
|
const refreshInterval = 5000; // 5 seconds |
|
setInterval(() => { |
|
// Find the refresh button by its text and click it |
|
const refreshButtons = Array.from(document.querySelectorAll('button')); |
|
const refreshBtn = refreshButtons.find(btn => btn.textContent.includes('Refresh Stats')); |
|
if (refreshBtn) { |
|
refreshBtn.click(); |
|
} |
|
}, refreshInterval); |
|
} |
|
|
|
// Set up the auto-refresh after the page loads |
|
if (window.gradio_loaded) { |
|
setupAutoRefresh(); |
|
} else { |
|
document.addEventListener('DOMContentLoaded', setupAutoRefresh); |
|
} |
|
</script> |
|
""") |
|
|
|
|
|
with gr.Accordion("Provide Feedback", open=False, visible=False) as feedback_accordion: |
|
gr.Markdown("### Help Improve the System") |
|
gr.Markdown("Your feedback helps us refine the hallucination detection system.") |
|
|
|
feedback_input = gr.Radio( |
|
label="Is the hallucination detection accurate?", |
|
choices=["Yes, correct detection", "No, incorrectly flagged hallucination", "No, missed hallucination", "Unsure/Other"], |
|
value="Yes, correct detection" |
|
) |
|
|
|
feedback_text = gr.Textbox( |
|
label="Additional comments (optional)", |
|
placeholder="Please provide any additional observations or details...", |
|
lines=2 |
|
) |
|
|
|
feedback_button = gr.Button("Submit Feedback", variant="secondary") |
|
feedback_status = gr.Textbox(label="Feedback Status", interactive=False, visible=False) |
|
|
|
|
|
|
|
|
|
hidden_results = gr.State() |
|
|
|
|
|
submit_button.click( |
|
fn=start_processing, |
|
inputs=[query_input], |
|
outputs=[progress_display, results_accordion, feedback_accordion, hidden_results], |
|
queue=False |
|
).then( |
|
fn=process_query_and_display_results, |
|
inputs=[query_input], |
|
outputs=[progress_display, results_accordion, feedback_accordion, hidden_results] |
|
) |
|
|
|
feedback_button.click( |
|
fn=combine_feedback, |
|
inputs=[feedback_input, feedback_text, hidden_results], |
|
outputs=[feedback_status] |
|
) |
|
|
|
|
|
gr.HTML( |
|
""" |
|
<footer> |
|
<p>Paraphrase-based Approach for Scrutinizing Systems (PAS2) - Advanced Hallucination Detection</p> |
|
<p>Using Mistral Large for generation and OpenAI o3-mini as judge</p> |
|
</footer> |
|
""" |
|
) |
|
|
|
return interface |
|
|
|
|
|
def test_progress():
    """Launch a minimal Gradio demo whose button walks a progress bar through fixed steps."""
    import gradio as gr
    import time

    def slow_process(progress=gr.Progress()):
        """Simulate the detection stages with fixed delays and report each one."""

        def report(fraction, label, pause=None):
            # Publish one progress step, then optionally wait before the next.
            progress(fraction, desc=label)
            if pause is not None:
                time.sleep(pause)

        report(0, "Starting process...", 0.5)

        report(0.15, "Generating paraphrases...", 1)
        report(0.3, "Paraphrases generated", 0.5)

        report(0.35, "Getting responses...")
        for done in (1, 2, 3):
            # Here the wait comes first: the step is reported once the "response" arrives.
            time.sleep(0.8)
            report(0.35 + (0.3 * (done / 3)), f"Getting responses ({done}/3)...")
        report(0.65, "All responses received", 0.5)

        report(0.7, "Analyzing responses for hallucinations...", 2)

        report(1.0, "Analysis complete!")
        return "Process completed successfully!"

    with gr.Blocks() as demo:
        with gr.Row():
            btn = gr.Button("Start Process")
            output = gr.Textbox(label="Result")

        btn.click(fn=slow_process, outputs=output)

    demo.launch()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it with fixed launch settings.
    logger.info("Starting PAS2 Hallucination Detector")
    interface = create_interface()

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,
        "quiet": True,
        "share": False,
        "max_threads": 10,
        "debug": False,
    }

    logger.info("Launching Gradio interface...")
    interface.launch(**launch_options)
|
|
|
|
|
|
|
|
|
|