Serhan Yılmaz commited on
Commit
96f160f
·
1 Parent(s): 696be0a
Files changed (2) hide show
  1. app.py +1469 -7
  2. pas2.py +0 -1483
app.py CHANGED
@@ -1,5 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
- from pas2 import create_interface
 
 
 
 
3
 
4
  # Configure logging
5
  logging.basicConfig(
@@ -12,10 +27,1457 @@ logging.basicConfig(
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
- # Create and launch the interface
16
- logger.info("Starting PAS2 Hallucination Detector")
17
- interface = create_interface()
18
- logger.info("Launching Gradio interface...")
 
 
19
 
20
- # This is the entry point for Hugging Face Spaces
21
- app = interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ from datetime import datetime
5
+ from pydantic import BaseModel, Field
6
+ from typing import List, Dict, Any, Optional
7
+ import numpy as np
8
+ from mistralai import Mistral
9
+ from openai import OpenAI
10
+ import re
11
+ import json
12
  import logging
13
+ import time
14
+ import concurrent.futures
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ import threading
17
+ import sqlite3
18
 
19
  # Configure logging
20
  logging.basicConfig(
 
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
+ class HallucinationJudgment(BaseModel):
31
+ hallucination_detected: bool = Field(description="Whether a hallucination is detected across the responses")
32
+ confidence_score: float = Field(description="Confidence score between 0-1 for the hallucination judgment")
33
+ conflicting_facts: List[Dict[str, Any]] = Field(description="List of conflicting facts found in the responses")
34
+ reasoning: str = Field(description="Detailed reasoning for the judgment")
35
+ summary: str = Field(description="A summary of the analysis")
36
 
37
+ class PAS2:
38
+ """Paraphrase-based Approach for Scrutinizing Systems - Using model-as-judge"""
39
+
40
+ def __init__(self, mistral_api_key=None, openai_api_key=None, progress_callback=None):
41
+ """Initialize the PAS2 with API keys"""
42
+ # For Hugging Face Spaces, we prioritize getting API keys from HF_* environment variables
43
+ # which are set from the Secrets tab in the Space settings
44
+ self.mistral_api_key = mistral_api_key or os.environ.get("HF_MISTRAL_API_KEY") or os.environ.get("MISTRAL_API_KEY")
45
+ self.openai_api_key = openai_api_key or os.environ.get("HF_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
46
+ self.progress_callback = progress_callback
47
+
48
+ if not self.mistral_api_key:
49
+ raise ValueError("Mistral API key is required. Set it via HF_MISTRAL_API_KEY in Hugging Face Spaces secrets or pass it as a parameter.")
50
+
51
+ if not self.openai_api_key:
52
+ raise ValueError("OpenAI API key is required. Set it via HF_OPENAI_API_KEY in Hugging Face Spaces secrets or pass it as a parameter.")
53
+
54
+ self.mistral_client = Mistral(api_key=self.mistral_api_key)
55
+ self.openai_client = OpenAI(api_key=self.openai_api_key)
56
+
57
+ self.mistral_model = "mistral-large-latest"
58
+ self.openai_model = "o3-mini"
59
+
60
+ logger.info("PAS2 initialized with Mistral model: %s and OpenAI model: %s",
61
+ self.mistral_model, self.openai_model)
62
+
63
+ def generate_paraphrases(self, query: str, n_paraphrases: int = 3) -> List[str]:
64
+ """Generate paraphrases of the input query using Mistral API"""
65
+ logger.info("Generating %d paraphrases for query: %s", n_paraphrases, query)
66
+ start_time = time.time()
67
+
68
+ messages = [
69
+ {
70
+ "role": "system",
71
+ "content": f"You are an expert at creating semantically equivalent paraphrases. Generate {n_paraphrases} different paraphrases of the given query that preserve the original meaning but vary in wording and structure. Return a JSON array of strings, each containing one paraphrase."
72
+ },
73
+ {
74
+ "role": "user",
75
+ "content": query
76
+ }
77
+ ]
78
+
79
+ try:
80
+ logger.info("Sending paraphrase generation request to Mistral API...")
81
+ response = self.mistral_client.chat.complete(
82
+ model=self.mistral_model,
83
+ messages=messages,
84
+ response_format={"type": "json_object"}
85
+ )
86
+
87
+ content = response.choices[0].message.content
88
+ logger.debug("Received raw paraphrase response: %s", content)
89
+
90
+ paraphrases_data = json.loads(content)
91
+
92
+ # Handle different possible JSON structures
93
+ if isinstance(paraphrases_data, dict) and "paraphrases" in paraphrases_data:
94
+ paraphrases = paraphrases_data["paraphrases"]
95
+ elif isinstance(paraphrases_data, dict) and "results" in paraphrases_data:
96
+ paraphrases = paraphrases_data["results"]
97
+ elif isinstance(paraphrases_data, list):
98
+ paraphrases = paraphrases_data
99
+ else:
100
+ # Try to extract a list from any field
101
+ for key, value in paraphrases_data.items():
102
+ if isinstance(value, list) and len(value) > 0:
103
+ paraphrases = value
104
+ break
105
+ else:
106
+ logger.warning("Could not extract paraphrases from response: %s", content)
107
+ raise ValueError(f"Could not extract paraphrases from response: {content}")
108
+
109
+ # Ensure we have the right number of paraphrases
110
+ paraphrases = paraphrases[:n_paraphrases]
111
+
112
+ # Add the original query as the first item
113
+ all_queries = [query] + paraphrases
114
+
115
+ elapsed_time = time.time() - start_time
116
+ logger.info("Generated %d paraphrases in %.2f seconds", len(paraphrases), elapsed_time)
117
+ for i, p in enumerate(paraphrases, 1):
118
+ logger.info("Paraphrase %d: %s", i, p)
119
+
120
+ return all_queries
121
+
122
+ except Exception as e:
123
+ logger.error("Error generating paraphrases: %s", str(e), exc_info=True)
124
+ # Return original plus simple paraphrases as fallback
125
+ fallback_paraphrases = [
126
+ query,
127
+ f"Could you tell me about {query.strip('?')}?",
128
+ f"I'd like to know: {query}",
129
+ f"Please provide information on {query.strip('?')}."
130
+ ][:n_paraphrases+1]
131
+
132
+ logger.info("Using fallback paraphrases due to error")
133
+ for i, p in enumerate(fallback_paraphrases[1:], 1):
134
+ logger.info("Fallback paraphrase %d: %s", i, p)
135
+
136
+ return fallback_paraphrases
137
+
138
+ def _get_single_response(self, query: str, index: int = None) -> str:
139
+ """Get a single response from Mistral API for a query"""
140
+ try:
141
+ query_description = f"Query {index}: {query}" if index is not None else f"Query: {query}"
142
+ logger.info("Getting response for %s", query_description)
143
+ start_time = time.time()
144
+
145
+ messages = [
146
+ {
147
+ "role": "system",
148
+ "content": "You are a helpful AI assistant. Provide accurate, factual information in response to questions."
149
+ },
150
+ {
151
+ "role": "user",
152
+ "content": query
153
+ }
154
+ ]
155
+
156
+ response = self.mistral_client.chat.complete(
157
+ model=self.mistral_model,
158
+ messages=messages
159
+ )
160
+
161
+ result = response.choices[0].message.content
162
+ elapsed_time = time.time() - start_time
163
+
164
+ logger.info("Received response for %s (%.2f seconds)", query_description, elapsed_time)
165
+ logger.debug("Response content for %s: %s", query_description, result[:100] + "..." if len(result) > 100 else result)
166
+
167
+ return result
168
+
169
+ except Exception as e:
170
+ error_msg = f"Error getting response for query '{query}': {e}"
171
+ logger.error(error_msg, exc_info=True)
172
+ return f"Error: Failed to get response for this query."
173
+
174
+ def get_responses(self, queries: List[str]) -> List[str]:
175
+ """Get responses from Mistral API for each query in parallel"""
176
+ logger.info("Getting responses for %d queries in parallel", len(queries))
177
+ start_time = time.time()
178
+
179
+ # Use ThreadPoolExecutor for parallel API calls
180
+ with ThreadPoolExecutor(max_workers=min(len(queries), 5)) as executor:
181
+ # Submit tasks and map them to their original indices
182
+ future_to_index = {
183
+ executor.submit(self._get_single_response, query, i): i
184
+ for i, query in enumerate(queries)
185
+ }
186
+
187
+ # Prepare a list with the correct length
188
+ responses = [""] * len(queries)
189
+
190
+ # Counter for completed responses
191
+ completed_count = 0
192
+
193
+ # Collect results as they complete
194
+ for future in concurrent.futures.as_completed(future_to_index):
195
+ index = future_to_index[future]
196
+ try:
197
+ responses[index] = future.result()
198
+
199
+ # Update completion count and report progress
200
+ completed_count += 1
201
+ if self.progress_callback:
202
+ self.progress_callback("responses_progress",
203
+ completed_responses=completed_count,
204
+ total_responses=len(queries))
205
+
206
+ except Exception as e:
207
+ logger.error("Error processing response for index %d: %s", index, str(e))
208
+ responses[index] = f"Error: Failed to get response for query {index}."
209
+
210
+ # Still update completion count even for errors
211
+ completed_count += 1
212
+ if self.progress_callback:
213
+ self.progress_callback("responses_progress",
214
+ completed_responses=completed_count,
215
+ total_responses=len(queries))
216
+
217
+ elapsed_time = time.time() - start_time
218
+ logger.info("Received all %d responses in %.2f seconds total", len(responses), elapsed_time)
219
+
220
+ return responses
221
+
222
+ def detect_hallucination(self, query: str, n_paraphrases: int = 3) -> Dict:
223
+ """
224
+ Detect hallucinations by comparing responses to paraphrased queries using a judge model
225
+
226
+ Returns:
227
+ Dict containing hallucination judgment and all responses
228
+ """
229
+ logger.info("Starting hallucination detection for query: %s", query)
230
+ start_time = time.time()
231
+
232
+ # Report progress
233
+ if self.progress_callback:
234
+ self.progress_callback("starting", query=query)
235
+
236
+ # Generate paraphrases
237
+ logger.info("Step 1: Generating paraphrases")
238
+ if self.progress_callback:
239
+ self.progress_callback("generating_paraphrases", query=query)
240
+
241
+ all_queries = self.generate_paraphrases(query, n_paraphrases)
242
+
243
+ if self.progress_callback:
244
+ self.progress_callback("paraphrases_complete", query=query, count=len(all_queries))
245
+
246
+ # Get responses to all queries
247
+ logger.info("Step 2: Getting responses to all %d queries", len(all_queries))
248
+ if self.progress_callback:
249
+ self.progress_callback("getting_responses", query=query, total=len(all_queries))
250
+
251
+ all_responses = []
252
+ for i, q in enumerate(all_queries):
253
+ logger.info("Getting response %d/%d for query: %s", i+1, len(all_queries), q)
254
+ if self.progress_callback:
255
+ self.progress_callback("responses_progress", query=query, completed=i, total=len(all_queries))
256
+
257
+ response = self._get_single_response(q, index=i)
258
+ all_responses.append(response)
259
+
260
+ if self.progress_callback:
261
+ self.progress_callback("responses_complete", query=query)
262
+
263
+ # Judge the responses for hallucinations
264
+ logger.info("Step 3: Judging for hallucinations")
265
+ if self.progress_callback:
266
+ self.progress_callback("judging", query=query)
267
+
268
+ # The first query is the original, rest are paraphrases
269
+ original_query = all_queries[0]
270
+ original_response = all_responses[0]
271
+ paraphrased_queries = all_queries[1:] if len(all_queries) > 1 else []
272
+ paraphrased_responses = all_responses[1:] if len(all_responses) > 1 else []
273
+
274
+ # Judge the responses
275
+ judgment = self.judge_hallucination(
276
+ original_query=original_query,
277
+ original_response=original_response,
278
+ paraphrased_queries=paraphrased_queries,
279
+ paraphrased_responses=paraphrased_responses
280
+ )
281
+
282
+ # Assemble the results
283
+ results = {
284
+ "original_query": original_query,
285
+ "original_response": original_response,
286
+ "paraphrased_queries": paraphrased_queries,
287
+ "paraphrased_responses": paraphrased_responses,
288
+ "hallucination_detected": judgment.hallucination_detected,
289
+ "confidence_score": judgment.confidence_score,
290
+ "conflicting_facts": judgment.conflicting_facts,
291
+ "reasoning": judgment.reasoning,
292
+ "summary": judgment.summary
293
+ }
294
+
295
+ # Report completion
296
+ if self.progress_callback:
297
+ self.progress_callback("complete", query=query)
298
+
299
+ logger.info("Hallucination detection completed in %.2f seconds", time.time() - start_time)
300
+ return results
301
+
302
+ def judge_hallucination(self,
303
+ original_query: str,
304
+ original_response: str,
305
+ paraphrased_queries: List[str],
306
+ paraphrased_responses: List[str]) -> HallucinationJudgment:
307
+ """
308
+ Use OpenAI's o3-mini as a judge to detect hallucinations in the responses
309
+ """
310
+ logger.info("Judging hallucinations with OpenAI's %s model", self.openai_model)
311
+ start_time = time.time()
312
+
313
+ # Prepare the context for the judge
314
+ context = f"""
315
+ Original Question: {original_query}
316
+
317
+ Original Response:
318
+ {original_response}
319
+
320
+ Paraphrased Questions and their Responses:
321
+ """
322
+
323
+ for i, (query, response) in enumerate(zip(paraphrased_queries, paraphrased_responses), 1):
324
+ context += f"\nParaphrased Question {i}: {query}\n\nResponse {i}:\n{response}\n"
325
+
326
+ system_prompt = """
327
+ You are a judge evaluating whether an AI is hallucinating across different responses to semantically equivalent questions.
328
+ Analyze all responses carefully to identify any factual inconsistencies or contradictions.
329
+ Focus on factual discrepancies, not stylistic differences.
330
+ A hallucination is when the AI states different facts in response to questions that are asking for the same information.
331
+
332
+ Your response should be a JSON with the following fields:
333
+ - hallucination_detected: boolean indicating whether hallucinations were found
334
+ - confidence_score: number between 0 and 1 representing your confidence in the judgment
335
+ - conflicting_facts: an array of objects describing any conflicting information found
336
+ - reasoning: detailed explanation for your judgment
337
+ - summary: a concise summary of your analysis
338
+ """
339
+
340
+ try:
341
+ logger.info("Sending judgment request to OpenAI API...")
342
+ response = self.openai_client.chat.completions.create(
343
+ model=self.openai_model,
344
+ messages=[
345
+ {"role": "system", "content": system_prompt},
346
+ {"role": "user", "content": f"Evaluate these responses for hallucinations:\n\n{context}"}
347
+ ],
348
+ response_format={"type": "json_object"}
349
+ )
350
+
351
+ result_json = json.loads(response.choices[0].message.content)
352
+ logger.debug("Received judgment response: %s", result_json)
353
+
354
+ # Create the HallucinationJudgment object from the JSON response
355
+ judgment = HallucinationJudgment(
356
+ hallucination_detected=result_json.get("hallucination_detected", False),
357
+ confidence_score=result_json.get("confidence_score", 0.0),
358
+ conflicting_facts=result_json.get("conflicting_facts", []),
359
+ reasoning=result_json.get("reasoning", "No reasoning provided."),
360
+ summary=result_json.get("summary", "No summary provided.")
361
+ )
362
+
363
+ elapsed_time = time.time() - start_time
364
+ logger.info("Judgment completed in %.2f seconds", elapsed_time)
365
+
366
+ return judgment
367
+
368
+ except Exception as e:
369
+ logger.error("Error in hallucination judgment: %s", str(e), exc_info=True)
370
+ # Return a fallback judgment
371
+ return HallucinationJudgment(
372
+ hallucination_detected=False,
373
+ confidence_score=0.0,
374
+ conflicting_facts=[],
375
+ reasoning="Failed to obtain judgment from the model.",
376
+ summary="Analysis failed due to API error."
377
+ )
378
+
379
+
380
+ class HallucinationDetectorApp:
381
+ def __init__(self):
382
+ self.pas2 = None
383
+ # Use the default HF Spaces persistent storage location
384
+ self.data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
385
+ self.db_path = os.path.join(self.data_dir, "feedback.db")
386
+ logger.info("Initializing HallucinationDetectorApp")
387
+ self._initialize_database()
388
+ self.progress_callback = None
389
+
390
+ def _initialize_database(self):
391
+ """Initialize SQLite database for feedback storage in persistent directory"""
392
+ try:
393
+ # Create data directory if it doesn't exist
394
+ os.makedirs(self.data_dir, exist_ok=True)
395
+ logger.info(f"Ensuring data directory exists at {self.data_dir}")
396
+
397
+ conn = sqlite3.connect(self.db_path)
398
+ cursor = conn.cursor()
399
+
400
+ # Create table if it doesn't exist
401
+ cursor.execute('''
402
+ CREATE TABLE IF NOT EXISTS feedback (
403
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
404
+ timestamp TEXT,
405
+ original_query TEXT,
406
+ original_response TEXT,
407
+ paraphrased_queries TEXT,
408
+ paraphrased_responses TEXT,
409
+ hallucination_detected INTEGER,
410
+ confidence_score REAL,
411
+ conflicting_facts TEXT,
412
+ reasoning TEXT,
413
+ summary TEXT,
414
+ user_feedback TEXT
415
+ )
416
+ ''')
417
+
418
+ conn.commit()
419
+ conn.close()
420
+ logger.info(f"Database initialized successfully at {self.db_path}")
421
+ except Exception as e:
422
+ logger.error(f"Error initializing database: {str(e)}", exc_info=True)
423
+ # Fallback to temporary directory if /data is not accessible
424
+ temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data")
425
+ os.makedirs(temp_dir, exist_ok=True)
426
+ self.db_path = os.path.join(temp_dir, "feedback.db")
427
+ logger.warning(f"Using fallback database location: {self.db_path}")
428
+
429
+ # Try creating database in fallback location
430
+ try:
431
+ conn = sqlite3.connect(self.db_path)
432
+ cursor = conn.cursor()
433
+ cursor.execute('''
434
+ CREATE TABLE IF NOT EXISTS feedback (
435
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
436
+ timestamp TEXT,
437
+ original_query TEXT,
438
+ original_response TEXT,
439
+ paraphrased_queries TEXT,
440
+ paraphrased_responses TEXT,
441
+ hallucination_detected INTEGER,
442
+ confidence_score REAL,
443
+ conflicting_facts TEXT,
444
+ reasoning TEXT,
445
+ summary TEXT,
446
+ user_feedback TEXT
447
+ )
448
+ ''')
449
+ conn.commit()
450
+ conn.close()
451
+ logger.info(f"Database initialized in fallback location")
452
+ except Exception as fallback_error:
453
+ logger.error(f"Critical error: Could not initialize database in fallback location: {str(fallback_error)}", exc_info=True)
454
+ raise
455
+
456
+ def set_progress_callback(self, callback):
457
+ """Set the progress callback function"""
458
+ self.progress_callback = callback
459
+
460
+ def initialize_api(self, mistral_api_key, openai_api_key):
461
+ """Initialize the PAS2 with API keys"""
462
+ try:
463
+ logger.info("Initializing PAS2 with API keys")
464
+ self.pas2 = PAS2(
465
+ mistral_api_key=mistral_api_key,
466
+ openai_api_key=openai_api_key,
467
+ progress_callback=self.progress_callback
468
+ )
469
+ logger.info("API initialization successful")
470
+ return "API keys set successfully! You can now use the application."
471
+ except Exception as e:
472
+ logger.error("Error initializing API: %s", str(e), exc_info=True)
473
+ return f"Error initializing API: {str(e)}"
474
+
475
+ def process_query(self, query: str):
476
+ """Process the query using PAS2"""
477
+ if not self.pas2:
478
+ logger.error("PAS2 not initialized")
479
+ return {
480
+ "error": "Please set API keys first before processing queries."
481
+ }
482
+
483
+ if not query.strip():
484
+ logger.warning("Empty query provided")
485
+ return {
486
+ "error": "Please enter a query."
487
+ }
488
+
489
+ try:
490
+ # Set the progress callback if needed
491
+ if self.progress_callback and self.pas2.progress_callback != self.progress_callback:
492
+ self.pas2.progress_callback = self.progress_callback
493
+
494
+ # Process the query
495
+ logger.info("Processing query with PAS2: %s", query)
496
+ results = self.pas2.detect_hallucination(query)
497
+ logger.info("Query processing completed successfully")
498
+ return results
499
+ except Exception as e:
500
+ logger.error("Error processing query: %s", str(e), exc_info=True)
501
+ return {
502
+ "error": f"Error processing query: {str(e)}"
503
+ }
504
+
505
+ def save_feedback(self, results, feedback):
506
+ """Save results and user feedback to SQLite database"""
507
+ try:
508
+ logger.info("Saving user feedback: %s", feedback)
509
+
510
+ conn = sqlite3.connect(self.db_path)
511
+ cursor = conn.cursor()
512
+
513
+ # Prepare data
514
+ data = (
515
+ datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
516
+ results.get('original_query', ''),
517
+ results.get('original_response', ''),
518
+ str(results.get('paraphrased_queries', [])),
519
+ str(results.get('paraphrased_responses', [])),
520
+ 1 if results.get('hallucination_detected', False) else 0,
521
+ results.get('confidence_score', 0.0),
522
+ str(results.get('conflicting_facts', [])),
523
+ results.get('reasoning', ''),
524
+ results.get('summary', ''),
525
+ feedback
526
+ )
527
+
528
+ # Insert data
529
+ cursor.execute('''
530
+ INSERT INTO feedback (
531
+ timestamp, original_query, original_response,
532
+ paraphrased_queries, paraphrased_responses,
533
+ hallucination_detected, confidence_score,
534
+ conflicting_facts, reasoning, summary, user_feedback
535
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
536
+ ''', data)
537
+
538
+ conn.commit()
539
+ conn.close()
540
+
541
+ logger.info("Feedback saved successfully to database")
542
+ return "Feedback saved successfully!"
543
+ except Exception as e:
544
+ logger.error("Error saving feedback: %s", str(e), exc_info=True)
545
+ return f"Error saving feedback: {str(e)}"
546
+
547
+ def get_feedback_stats(self):
548
+ """Get statistics about collected feedback"""
549
+ try:
550
+ conn = sqlite3.connect(self.db_path)
551
+ cursor = conn.cursor()
552
+
553
+ # Get total feedback count
554
+ cursor.execute("SELECT COUNT(*) FROM feedback")
555
+ total_count = cursor.fetchone()[0]
556
+
557
+ # Get hallucination detection stats
558
+ cursor.execute("""
559
+ SELECT hallucination_detected, COUNT(*)
560
+ FROM feedback
561
+ GROUP BY hallucination_detected
562
+ """)
563
+ detection_stats = dict(cursor.fetchall())
564
+
565
+ # Get average confidence score
566
+ cursor.execute("SELECT AVG(confidence_score) FROM feedback")
567
+ avg_confidence = cursor.fetchone()[0] or 0
568
+
569
+ conn.close()
570
+
571
+ return {
572
+ "total_feedback": total_count,
573
+ "hallucinations_detected": detection_stats.get(1, 0),
574
+ "no_hallucinations": detection_stats.get(0, 0),
575
+ "average_confidence": round(avg_confidence, 2)
576
+ }
577
+ except Exception as e:
578
+ logger.error("Error getting feedback stats: %s", str(e), exc_info=True)
579
+ return None
580
+
581
+
582
+ # Progress tracking for UI updates
583
+ class ProgressTracker:
584
+ """Tracks progress of hallucination detection for UI updates"""
585
+
586
+ STAGES = {
587
+ "idle": {"status": "Ready", "progress": 0, "color": "#757575"},
588
+ "starting": {"status": "Starting process...", "progress": 5, "color": "#2196F3"},
589
+ "generating_paraphrases": {"status": "Generating paraphrases...", "progress": 15, "color": "#2196F3"},
590
+ "paraphrases_complete": {"status": "Paraphrases generated", "progress": 30, "color": "#2196F3"},
591
+ "getting_responses": {"status": "Getting responses (0/0)...", "progress": 35, "color": "#2196F3"},
592
+ "responses_progress": {"status": "Getting responses ({completed}/{total})...", "progress": 40, "color": "#2196F3"},
593
+ "responses_complete": {"status": "All responses received", "progress": 65, "color": "#2196F3"},
594
+ "judging": {"status": "Analyzing responses for hallucinations...", "progress": 70, "color": "#2196F3"},
595
+ "complete": {"status": "Analysis complete!", "progress": 100, "color": "#4CAF50"},
596
+ "error": {"status": "Error: {error_message}", "progress": 100, "color": "#F44336"}
597
+ }
598
+
599
+ def __init__(self):
600
+ self.stage = "idle"
601
+ self.stage_data = self.STAGES[self.stage].copy()
602
+ self.query = ""
603
+ self.completed_responses = 0
604
+ self.total_responses = 0
605
+ self.error_message = ""
606
+ self._lock = threading.Lock()
607
+ self._status_callback = None
608
+ self._stop_event = threading.Event()
609
+ self._update_thread = None
610
+
611
+ def register_callback(self, callback_fn):
612
+ """Register callback function to update UI"""
613
+ self._status_callback = callback_fn
614
+
615
+ def update_stage(self, stage, **kwargs):
616
+ """Update the current stage and trigger callback"""
617
+ with self._lock:
618
+ if stage in self.STAGES:
619
+ self.stage = stage
620
+ self.stage_data = self.STAGES[stage].copy()
621
+
622
+ # Update with any additional parameters
623
+ for key, value in kwargs.items():
624
+ if key == 'query':
625
+ self.query = value
626
+ elif key == 'completed_responses':
627
+ self.completed_responses = value
628
+ elif key == 'total_responses':
629
+ self.total_responses = value
630
+ elif key == 'error_message':
631
+ self.error_message = value
632
+
633
+ # Format status message
634
+ if stage == 'responses_progress':
635
+ self.stage_data['status'] = self.stage_data['status'].format(
636
+ completed=self.completed_responses,
637
+ total=self.total_responses
638
+ )
639
+ elif stage == 'error':
640
+ self.stage_data['status'] = self.stage_data['status'].format(
641
+ error_message=self.error_message
642
+ )
643
+
644
+ if self._status_callback:
645
+ self._status_callback(self.get_html_status())
646
+
647
+ def get_html_status(self):
648
+ """Get HTML representation of current status"""
649
+ progress_width = f"{self.stage_data['progress']}%"
650
+ status_text = self.stage_data['status']
651
+ color = self.stage_data['color']
652
+
653
+ query_info = f'<div class="query-display">{self.query}</div>' if self.query else ''
654
+
655
+ # Only show status text if not in idle state
656
+ status_display = f'<div class="progress-status" style="color: {color};">{status_text}</div>' if self.stage != "idle" else ''
657
+
658
+ html = f"""
659
+ <div class="progress-container">
660
+ {query_info}
661
+ {status_display}
662
+ <div class="progress-bar-container">
663
+ <div class="progress-bar" style="width: {progress_width}; background-color: {color};"></div>
664
+ </div>
665
+ </div>
666
+ """
667
+ return html
668
+
669
+ def start_pulsing(self):
670
+ """Start a pulsing animation for the progress bar during long operations"""
671
+ if self._update_thread and self._update_thread.is_alive():
672
+ return
673
+
674
+ self._stop_event.clear()
675
+ self._update_thread = threading.Thread(target=self._pulse_progress)
676
+ self._update_thread.daemon = True
677
+ self._update_thread.start()
678
+
679
+ def stop_pulsing(self):
680
+ """Stop the pulsing animation"""
681
+ self._stop_event.set()
682
+ if self._update_thread:
683
+ self._update_thread.join(0.5)
684
+
685
+ def _pulse_progress(self):
686
+ """Animate the progress bar to show activity"""
687
+ pulse_stages = ["⋯", "⋯⋯", "⋯⋯⋯", "⋯⋯", "⋯"]
688
+ i = 0
689
+ while not self._stop_event.is_set():
690
+ with self._lock:
691
+ if self.stage not in ["idle", "complete", "error"]:
692
+ status_base = self.stage_data['status'].split("...")[0] if "..." in self.stage_data['status'] else self.stage_data['status']
693
+ self.stage_data['status'] = f"{status_base}... {pulse_stages[i]}"
694
+
695
+ if self._status_callback:
696
+ self._status_callback(self.get_html_status())
697
+
698
+ i = (i + 1) % len(pulse_stages)
699
+ time.sleep(0.3)
700
+
701
+
702
+ def create_interface():
703
+ """Create Gradio interface"""
704
+ detector = HallucinationDetectorApp()
705
+
706
+ # Initialize Progress Tracker
707
+ progress_tracker = ProgressTracker()
708
+
709
+ # Initialize APIs from environment variables automatically
710
+ try:
711
+ detector.initialize_api(
712
+ mistral_api_key=os.environ.get("HF_MISTRAL_API_KEY"),
713
+ openai_api_key=os.environ.get("HF_OPENAI_API_KEY")
714
+ )
715
+ except Exception as e:
716
+ print(f"Warning: Failed to initialize APIs from environment variables: {e}")
717
+ print("Please make sure HF_MISTRAL_API_KEY and HF_OPENAI_API_KEY are set in your environment")
718
+
719
+ # CSS for styling
720
+ css = """
721
+ .container {
722
+ max-width: 1000px;
723
+ margin: 0 auto;
724
+ }
725
+ .title {
726
+ text-align: center;
727
+ margin-bottom: 0.5em;
728
+ color: #1a237e;
729
+ font-weight: 600;
730
+ }
731
+ .subtitle {
732
+ text-align: center;
733
+ margin-bottom: 1.5em;
734
+ color: #455a64;
735
+ font-size: 1.2em;
736
+ }
737
+ .section-title {
738
+ margin-top: 1em;
739
+ margin-bottom: 0.5em;
740
+ font-weight: bold;
741
+ color: #283593;
742
+ }
743
+ .info-box {
744
+ padding: 1.2em;
745
+ border-radius: 8px;
746
+ background-color: #f5f5f5;
747
+ margin-bottom: 1em;
748
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
749
+ }
750
+ .hallucination-positive {
751
+ padding: 1.2em;
752
+ border-radius: 8px;
753
+ background-color: #ffebee;
754
+ border-left: 5px solid #f44336;
755
+ margin-bottom: 1em;
756
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
757
+ }
758
+ .hallucination-negative {
759
+ padding: 1.2em;
760
+ border-radius: 8px;
761
+ background-color: #e8f5e9;
762
+ border-left: 5px solid #4caf50;
763
+ margin-bottom: 1em;
764
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
765
+ }
766
+ .response-box {
767
+ padding: 1.2em;
768
+ border-radius: 8px;
769
+ background-color: #f5f5f5;
770
+ margin-bottom: 0.8em;
771
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
772
+ }
773
+ .example-queries {
774
+ display: flex;
775
+ flex-wrap: wrap;
776
+ gap: 8px;
777
+ margin-bottom: 15px;
778
+ }
779
+ .example-query {
780
+ background-color: #e3f2fd;
781
+ padding: 8px 15px;
782
+ border-radius: 18px;
783
+ font-size: 0.9em;
784
+ cursor: pointer;
785
+ transition: all 0.2s;
786
+ border: 1px solid #bbdefb;
787
+ }
788
+ .example-query:hover {
789
+ background-color: #bbdefb;
790
+ box-shadow: 0 2px 5px rgba(0,0,0,0.1);
791
+ }
792
+ .stats-section {
793
+ display: flex;
794
+ justify-content: space-between;
795
+ background-color: #e8eaf6;
796
+ padding: 15px;
797
+ border-radius: 8px;
798
+ margin-bottom: 20px;
799
+ }
800
+ .stat-item {
801
+ text-align: center;
802
+ padding: 10px;
803
+ }
804
+ .stat-value {
805
+ font-size: 1.5em;
806
+ font-weight: bold;
807
+ color: #303f9f;
808
+ }
809
+ .stat-label {
810
+ font-size: 0.9em;
811
+ color: #5c6bc0;
812
+ }
813
+ .feedback-section {
814
+ border-top: 1px solid #e0e0e0;
815
+ padding-top: 15px;
816
+ margin-top: 20px;
817
+ }
818
+ footer {
819
+ text-align: center;
820
+ padding: 20px;
821
+ margin-top: 30px;
822
+ color: #9e9e9e;
823
+ font-size: 0.9em;
824
+ }
825
+ .processing-status {
826
+ padding: 12px;
827
+ background-color: #fff3e0;
828
+ border-left: 4px solid #ff9800;
829
+ margin-bottom: 15px;
830
+ font-weight: 500;
831
+ color: #e65100;
832
+ }
833
+ .debug-panel {
834
+ background-color: #f5f5f5;
835
+ border: 1px solid #e0e0e0;
836
+ border-radius: 4px;
837
+ padding: 10px;
838
+ margin-top: 15px;
839
+ font-family: monospace;
840
+ font-size: 0.9em;
841
+ white-space: pre-wrap;
842
+ max-height: 200px;
843
+ overflow-y: auto;
844
+ }
845
+ .progress-container {
846
+ padding: 15px;
847
+ background-color: #fff;
848
+ border-radius: 8px;
849
+ box-shadow: 0 2px 5px rgba(0,0,0,0.05);
850
+ margin-bottom: 15px;
851
+ }
852
+ .progress-status {
853
+ font-weight: 500;
854
+ margin-bottom: 8px;
855
+ padding: 4px 0;
856
+ font-size: 0.95em;
857
+ }
858
+ .progress-bar-container {
859
+ background-color: #e0e0e0;
860
+ height: 10px;
861
+ border-radius: 5px;
862
+ overflow: hidden;
863
+ margin-bottom: 10px;
864
+ box-shadow: inset 0 1px 3px rgba(0,0,0,0.1);
865
+ }
866
+ .progress-bar {
867
+ height: 100%;
868
+ transition: width 0.5s ease;
869
+ background-image: linear-gradient(to right, #2196F3, #3f51b5);
870
+ }
871
+ .query-display {
872
+ font-style: italic;
873
+ color: #666;
874
+ margin-bottom: 10px;
875
+ background-color: #f5f5f5;
876
+ padding: 8px;
877
+ border-radius: 4px;
878
+ border-left: 3px solid #2196F3;
879
+ }
880
+ """
881
+
882
+ # Example queries
883
+ example_queries = [
884
+ "Who was the first person to land on the moon?",
885
+ "What is the capital of France?",
886
+ "How many planets are in our solar system?",
887
+ "Who wrote the novel 1984?",
888
+ "What is the speed of light?",
889
+ "What was the first computer?"
890
+ ]
891
+
892
+ # Function to update the progress display
893
+ def update_progress_display(html):
894
+ """Update the progress display with the provided HTML"""
895
+ return gr.update(visible=True, value=html)
896
+
897
+ # Register the callback with the tracker
898
+ progress_tracker.register_callback(update_progress_display)
899
+
900
+ # Register the tracker with the detector
901
+ detector.set_progress_callback(progress_tracker.update_stage)
902
+
903
+ # Helper function to set example query
904
+ def set_example_query(example):
905
+ return example
906
+
907
+ # Function to show processing is starting
908
+ def start_processing(query):
909
+ logger.info("Processing query: %s", query)
910
+ # Stop any existing pulsing to prepare for incremental progress updates
911
+ progress_tracker.stop_pulsing()
912
+
913
+ # Reset to a processing state without the "Ready" text
914
+ # Use "starting" stage but with minimal UI display
915
+ progress_tracker.stage = "starting"
916
+ progress_tracker.query = query
917
+
918
+ # Force UI update with clean display
919
+ if progress_tracker._status_callback:
920
+ progress_tracker._status_callback(progress_tracker.get_html_status())
921
+
922
+ return [
923
+ gr.update(visible=True), # Show the progress display
924
+ gr.update(visible=False), # Hide the results accordion
925
+ gr.update(visible=False), # Hide the feedback accordion
926
+ None # Reset hidden results
927
+ ]
928
+
929
+ # Main processing function
930
+ def process_query_and_display_results(query, progress=gr.Progress()):
931
+ if not query.strip():
932
+ logger.warning("Empty query submitted")
933
+ progress_tracker.stop_pulsing()
934
+ progress_tracker.update_stage("error", error_message="Please enter a query.")
935
+ return [
936
+ gr.update(visible=True), # Show the progress with error
937
+ gr.update(visible=False),
938
+ gr.update(visible=False),
939
+ None
940
+ ]
941
+
942
+ # Check if API is initialized
943
+ if not detector.pas2:
944
+ try:
945
+ # Try to initialize from environment variables
946
+ logger.info("Initializing APIs from environment variables")
947
+ progress(0.05, desc="Initializing API...")
948
+ init_message = detector.initialize_api(
949
+ mistral_api_key=os.environ.get("HF_MISTRAL_API_KEY"),
950
+ openai_api_key=os.environ.get("HF_OPENAI_API_KEY")
951
+ )
952
+ if "successfully" not in init_message:
953
+ logger.error("Failed to initialize APIs: %s", init_message)
954
+ progress_tracker.stop_pulsing()
955
+ progress_tracker.update_stage("error", error_message="API keys not found in environment variables.")
956
+ return [
957
+ gr.update(visible=True),
958
+ gr.update(visible=False),
959
+ gr.update(visible=False),
960
+ None
961
+ ]
962
+ except Exception as e:
963
+ logger.error("Error initializing API: %s", str(e), exc_info=True)
964
+ progress_tracker.stop_pulsing()
965
+ progress_tracker.update_stage("error", error_message=f"Error initializing API: {str(e)}")
966
+ return [
967
+ gr.update(visible=True),
968
+ gr.update(visible=False),
969
+ gr.update(visible=False),
970
+ None
971
+ ]
972
+
973
+ try:
974
+ # Process the query
975
+ logger.info("Starting hallucination detection process")
976
+ start_time = time.time()
977
+
978
+ # Set up a custom progress callback that uses both the progress_tracker and the gr.Progress
979
+ def combined_progress_callback(stage, **kwargs):
980
+ # Skip the idle stage, which shows "Ready"
981
+ if stage == "idle":
982
+ return
983
+
984
+ progress_tracker.update_stage(stage, **kwargs)
985
+
986
+ # Map the stages to progress values for the gr.Progress bar
987
+ stage_to_progress = {
988
+ "starting": 0.05,
989
+ "generating_paraphrases": 0.15,
990
+ "paraphrases_complete": 0.3,
991
+ "getting_responses": 0.35,
992
+ "responses_progress": lambda kwargs: 0.35 + (0.3 * (kwargs.get("completed", 0) / max(kwargs.get("total", 1), 1))),
993
+ "responses_complete": 0.65,
994
+ "judging": 0.7,
995
+ "complete": 1.0,
996
+ "error": 1.0
997
+ }
998
+
999
+ # Update the gr.Progress bar
1000
+ if stage in stage_to_progress:
1001
+ prog_value = stage_to_progress[stage]
1002
+ if callable(prog_value):
1003
+ prog_value = prog_value(kwargs)
1004
+
1005
+ desc = progress_tracker.STAGES[stage]["status"]
1006
+ if "{" in desc and "}" in desc:
1007
+ # Format the description with any kwargs
1008
+ desc = desc.format(**kwargs)
1009
+
1010
+ # Ensure UI updates by adding a small delay
1011
+ # This forces the progress updates to be rendered
1012
+ progress(prog_value, desc=desc)
1013
+
1014
+ # For certain key stages, add a small sleep to ensure progress is visible
1015
+ if stage in ["starting", "generating_paraphrases", "paraphrases_complete",
1016
+ "getting_responses", "responses_complete", "judging", "complete"]:
1017
+ time.sleep(0.2) # Small delay to ensure UI update is visible
1018
+
1019
+ # Use these steps for processing
1020
+ detector.set_progress_callback(combined_progress_callback)
1021
+
1022
+ # Create a wrapper function for detect_hallucination that gives more control over progress updates
1023
+ def run_detection_with_visible_progress():
1024
+ # Step 1: Start
1025
+ combined_progress_callback("starting", query=query)
1026
+ time.sleep(0.3) # Ensure starting status is visible
1027
+
1028
+ # Step 2: Generate paraphrases (15-30%)
1029
+ combined_progress_callback("generating_paraphrases", query=query)
1030
+ all_queries = detector.pas2.generate_paraphrases(query)
1031
+ combined_progress_callback("paraphrases_complete", query=query, count=len(all_queries))
1032
+
1033
+ # Step 3: Get responses (35-65%)
1034
+ combined_progress_callback("getting_responses", query=query, total=len(all_queries))
1035
+ all_responses = []
1036
+ for i, q in enumerate(all_queries):
1037
+ # Show incremental progress for each response
1038
+ combined_progress_callback("responses_progress", query=query, completed=i, total=len(all_queries))
1039
+ response = detector.pas2._get_single_response(q, index=i)
1040
+ all_responses.append(response)
1041
+ combined_progress_callback("responses_complete", query=query)
1042
+
1043
+ # Step 4: Judge hallucinations (70-100%)
1044
+ combined_progress_callback("judging", query=query)
1045
+
1046
+ # The first query is the original, rest are paraphrases
1047
+ original_query = all_queries[0]
1048
+ original_response = all_responses[0]
1049
+ paraphrased_queries = all_queries[1:] if len(all_queries) > 1 else []
1050
+ paraphrased_responses = all_responses[1:] if len(all_responses) > 1 else []
1051
+
1052
+ # Judge the responses
1053
+ judgment = detector.pas2.judge_hallucination(
1054
+ original_query=original_query,
1055
+ original_response=original_response,
1056
+ paraphrased_queries=paraphrased_queries,
1057
+ paraphrased_responses=paraphrased_responses
1058
+ )
1059
+
1060
+ # Assemble the results
1061
+ results = {
1062
+ "original_query": original_query,
1063
+ "original_response": original_response,
1064
+ "paraphrased_queries": paraphrased_queries,
1065
+ "paraphrased_responses": paraphrased_responses,
1066
+ "hallucination_detected": judgment.hallucination_detected,
1067
+ "confidence_score": judgment.confidence_score,
1068
+ "conflicting_facts": judgment.conflicting_facts,
1069
+ "reasoning": judgment.reasoning,
1070
+ "summary": judgment.summary
1071
+ }
1072
+
1073
+ # Show completion
1074
+ combined_progress_callback("complete", query=query)
1075
+ time.sleep(0.3) # Ensure complete status is visible
1076
+
1077
+ return results
1078
+
1079
+ # Run the detection process with visible progress
1080
+ results = run_detection_with_visible_progress()
1081
+
1082
+ # Calculate elapsed time
1083
+ elapsed_time = time.time() - start_time
1084
+ logger.info("Hallucination detection completed in %.2f seconds", elapsed_time)
1085
+
1086
+ # Check for errors
1087
+ if "error" in results:
1088
+ logger.error("Error in results: %s", results["error"])
1089
+ progress_tracker.stop_pulsing()
1090
+ progress_tracker.update_stage("error", error_message=results["error"])
1091
+ return [
1092
+ gr.update(visible=True),
1093
+ gr.update(visible=False),
1094
+ gr.update(visible=False),
1095
+ None
1096
+ ]
1097
+
1098
+ # Prepare responses for display
1099
+ original_query = results["original_query"]
1100
+ original_response = results["original_response"]
1101
+
1102
+ paraphrased_queries = results["paraphrased_queries"]
1103
+ paraphrased_responses = results["paraphrased_responses"]
1104
+
1105
+ hallucination_detected = results["hallucination_detected"]
1106
+ confidence = results["confidence_score"]
1107
+ reasoning = results["reasoning"]
1108
+ summary = results["summary"]
1109
+
1110
+ # Format conflicting facts
1111
+ conflicting_facts = results["conflicting_facts"]
1112
+ conflicting_facts_text = ""
1113
+ if conflicting_facts:
1114
+ for i, fact in enumerate(conflicting_facts, 1):
1115
+ conflicting_facts_text += f"{i}. "
1116
+ if isinstance(fact, dict):
1117
+ for key, value in fact.items():
1118
+ conflicting_facts_text += f"{key}: {value}, "
1119
+ conflicting_facts_text = conflicting_facts_text.rstrip(", ")
1120
+ else:
1121
+ conflicting_facts_text += str(fact)
1122
+ conflicting_facts_text += "\n"
1123
+
1124
+ # Format responses to escape any backslashes
1125
+ original_response_safe = original_response.replace('\\', '\\\\').replace('\n', '<br>')
1126
+ paraphrased_responses_safe = [r.replace('\\', '\\\\').replace('\n', '<br>') for r in paraphrased_responses]
1127
+ reasoning_safe = reasoning.replace('\\', '\\\\').replace('\n', '<br>')
1128
+ conflicting_facts_text_safe = conflicting_facts_text.replace('\\', '\\\\').replace('\n', '<br>') if conflicting_facts_text else "None identified"
1129
+
1130
+ html_output = f"""
1131
+ <div class="container">
1132
+ <h2 class="title">Hallucination Detection Results</h2>
1133
+
1134
+ <div class="stats-section">
1135
+ <div class="stat-item">
1136
+ <div class="stat-value">{'Yes' if hallucination_detected else 'No'}</div>
1137
+ <div class="stat-label">Hallucination Detected</div>
1138
+ </div>
1139
+ <div class="stat-item">
1140
+ <div class="stat-value">{confidence:.2f}</div>
1141
+ <div class="stat-label">Confidence Score</div>
1142
+ </div>
1143
+ <div class="stat-item">
1144
+ <div class="stat-value">{len(paraphrased_queries)}</div>
1145
+ <div class="stat-label">Paraphrases Analyzed</div>
1146
+ </div>
1147
+ <div class="stat-item">
1148
+ <div class="stat-value">{elapsed_time:.1f}s</div>
1149
+ <div class="stat-label">Processing Time</div>
1150
+ </div>
1151
+ </div>
1152
+
1153
+ <div class="{'hallucination-positive' if hallucination_detected else 'hallucination-negative'}">
1154
+ <h3>Analysis Summary</h3>
1155
+ <p>{summary}</p>
1156
+ </div>
1157
+
1158
+ <div class="section-title">Original Query</div>
1159
+ <div class="response-box">
1160
+ {original_query}
1161
+ </div>
1162
+
1163
+ <div class="section-title">Original Response</div>
1164
+ <div class="response-box">
1165
+ {original_response_safe}
1166
+ </div>
1167
+
1168
+ <div class="section-title">Paraphrased Queries and Responses</div>
1169
+ """
1170
+
1171
+ for i, (q, r) in enumerate(zip(paraphrased_queries, paraphrased_responses_safe), 1):
1172
+ html_output += f"""
1173
+ <div class="section-title">Paraphrased Query {i}</div>
1174
+ <div class="response-box">
1175
+ {q}
1176
+ </div>
1177
+
1178
+ <div class="section-title">Response {i}</div>
1179
+ <div class="response-box">
1180
+ {r}
1181
+ </div>
1182
+ """
1183
+
1184
+ html_output += f"""
1185
+ <div class="section-title">Detailed Analysis</div>
1186
+ <div class="info-box">
1187
+ <p><strong>Reasoning:</strong></p>
1188
+ <p>{reasoning_safe}</p>
1189
+
1190
+ <p><strong>Conflicting Facts:</strong></p>
1191
+ <p>{conflicting_facts_text_safe}</p>
1192
+ </div>
1193
+ </div>
1194
+ """
1195
+
1196
+ logger.info("Updating UI with results")
1197
+ progress_tracker.stop_pulsing()
1198
+
1199
+ return [
1200
+ gr.update(visible=False), # Hide progress display when showing results
1201
+ gr.update(visible=True, value=html_output),
1202
+ gr.update(visible=True),
1203
+ results
1204
+ ]
1205
+
1206
+ except Exception as e:
1207
+ logger.error("Error processing query: %s", str(e), exc_info=True)
1208
+ progress_tracker.stop_pulsing()
1209
+ progress_tracker.update_stage("error", error_message=f"Error processing query: {str(e)}")
1210
+ return [
1211
+ gr.update(visible=True),
1212
+ gr.update(visible=False),
1213
+ gr.update(visible=False),
1214
+ None
1215
+ ]
1216
+
1217
+ # Helper function to submit feedback and update stats
1218
+ def combine_feedback(fb_input, fb_text, results):
1219
+ combined_feedback = f"{fb_input}: {fb_text}" if fb_text else fb_input
1220
+ if not results:
1221
+ return "No results to attach feedback to.", ""
1222
+
1223
+ response = detector.save_feedback(results, combined_feedback)
1224
+
1225
+ # Get updated stats
1226
+ stats = detector.get_feedback_stats()
1227
+ if stats:
1228
+ stats_html = f"""
1229
+ <div class="stats-section" style="margin-top: 15px;">
1230
+ <div class="stat-item">
1231
+ <div class="stat-value">{stats['total_feedback']}</div>
1232
+ <div class="stat-label">Total Feedback</div>
1233
+ </div>
1234
+ <div class="stat-item">
1235
+ <div class="stat-value">{stats['hallucinations_detected']}</div>
1236
+ <div class="stat-label">Hallucinations Found</div>
1237
+ </div>
1238
+ <div class="stat-item">
1239
+ <div class="stat-value">{stats['no_hallucinations']}</div>
1240
+ <div class="stat-label">No Hallucinations</div>
1241
+ </div>
1242
+ <div class="stat-item">
1243
+ <div class="stat-value">{stats['average_confidence']}</div>
1244
+ <div class="stat-label">Avg. Confidence</div>
1245
+ </div>
1246
+ </div>
1247
+ """
1248
+ else:
1249
+ stats_html = ""
1250
+
1251
+ return response, stats_html
1252
+
1253
+ # Create the interface
1254
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as interface:
1255
+ gr.HTML(
1256
+ """
1257
+ <div style="text-align: center; margin-bottom: 1.5rem">
1258
+ <h1 style="font-size: 2.2em; font-weight: 600; color: #1a237e; margin-bottom: 0.2em;">PAS2 - Hallucination Detector</h1>
1259
+ <h3 style="font-size: 1.3em; color: #455a64; margin-bottom: 0.8em;">Advanced AI Response Verification Using Model-as-Judge</h3>
1260
+ <p style="font-size: 1.1em; color: #546e7a; max-width: 800px; margin: 0 auto;">
1261
+ This tool detects hallucinations in AI responses by comparing answers to semantically equivalent questions and using a specialized judge model.
1262
+ </p>
1263
+ </div>
1264
+ """
1265
+ )
1266
+
1267
+ with gr.Accordion("About this Tool", open=False):
1268
+ gr.Markdown(
1269
+ """
1270
+ ### How It Works
1271
+
1272
+ This tool implements the Paraphrase-based Approach for Scrutinizing Systems (PAS2) with a model-as-judge enhancement:
1273
+
1274
+ 1. **Paraphrase Generation**: Your question is paraphrased multiple ways while preserving its core meaning
1275
+ 2. **Multiple Responses**: All questions (original + paraphrases) are sent to Mistral Large model
1276
+ 3. **Expert Judgment**: OpenAI's o3-mini analyzes all responses to detect factual inconsistencies
1277
+
1278
+ ### Why This Approach?
1279
+
1280
+ When an AI hallucinates, it often provides different answers to the same question when phrased differently.
1281
+ By using a separate judge model, we can identify these inconsistencies more effectively than with
1282
+ metric-based approaches.
1283
+
1284
+ ### Understanding the Results
1285
+
1286
+ - **Confidence Score**: Indicates the judge's confidence in the hallucination detection
1287
+ - **Conflicting Facts**: Specific inconsistencies found across responses
1288
+ - **Reasoning**: The judge's detailed analysis explaining its decision
1289
+
1290
+ ### Privacy Notice
1291
+
1292
+ Your queries and the system's responses are saved to help improve hallucination detection.
1293
+ No personally identifiable information is collected.
1294
+ """
1295
+ )
1296
+
1297
+ with gr.Row():
1298
+ with gr.Column():
1299
+ # First define the query input
1300
+ gr.Markdown("### Enter Your Question")
1301
+ with gr.Row():
1302
+ query_input = gr.Textbox(
1303
+ label="",
1304
+ placeholder="Ask a factual question (e.g., Who was the first person to land on the moon?)",
1305
+ lines=3
1306
+ )
1307
+
1308
+ # Now define the example queries
1309
+ gr.Markdown("### Or Try an Example")
1310
+ example_row = gr.Row()
1311
+ with example_row:
1312
+ for example in example_queries:
1313
+ example_btn = gr.Button(
1314
+ example,
1315
+ elem_classes=["example-query"],
1316
+ scale=0
1317
+ )
1318
+ example_btn.click(
1319
+ fn=set_example_query,
1320
+ inputs=[gr.Textbox(value=example, visible=False)],
1321
+ outputs=[query_input]
1322
+ )
1323
+
1324
+ with gr.Row():
1325
+ submit_button = gr.Button("Detect Hallucinations", variant="primary", scale=1)
1326
+
1327
+ # Error message
1328
+ error_message = gr.HTML(
1329
+ label="Status",
1330
+ visible=False
1331
+ )
1332
+
1333
+ # Progress display
1334
+ progress_display = gr.HTML(
1335
+ value=progress_tracker.get_html_status(),
1336
+ visible=True
1337
+ )
1338
+
1339
+ # Results display
1340
+ results_accordion = gr.HTML(visible=False)
1341
+
1342
+ # Add feedback stats display
1343
+ feedback_stats = gr.HTML(visible=True)
1344
+
1345
+ # Feedback section
1346
+ with gr.Accordion("Provide Feedback", open=False, visible=False) as feedback_accordion:
1347
+ gr.Markdown("### Help Improve the System")
1348
+ gr.Markdown("Your feedback helps us refine the hallucination detection system.")
1349
+
1350
+ feedback_input = gr.Radio(
1351
+ label="Is the hallucination detection accurate?",
1352
+ choices=["Yes, correct detection", "No, incorrectly flagged hallucination", "No, missed hallucination", "Unsure/Other"],
1353
+ value="Yes, correct detection"
1354
+ )
1355
+
1356
+ feedback_text = gr.Textbox(
1357
+ label="Additional comments (optional)",
1358
+ placeholder="Please provide any additional observations or details...",
1359
+ lines=2
1360
+ )
1361
+
1362
+ feedback_button = gr.Button("Submit Feedback", variant="secondary")
1363
+ feedback_status = gr.Textbox(label="Feedback Status", interactive=False, visible=False)
1364
+
1365
+ # Initialize feedback stats
1366
+ initial_stats = detector.get_feedback_stats()
1367
+ if initial_stats:
1368
+ feedback_stats.value = f"""
1369
+ <div class="stats-section">
1370
+ <div class="stat-item">
1371
+ <div class="stat-value">{initial_stats['total_feedback']}</div>
1372
+ <div class="stat-label">Total Feedback</div>
1373
+ </div>
1374
+ <div class="stat-item">
1375
+ <div class="stat-value">{initial_stats['hallucinations_detected']}</div>
1376
+ <div class="stat-label">Hallucinations Found</div>
1377
+ </div>
1378
+ <div class="stat-item">
1379
+ <div class="stat-value">{initial_stats['no_hallucinations']}</div>
1380
+ <div class="stat-label">No Hallucinations</div>
1381
+ </div>
1382
+ <div class="stat-item">
1383
+ <div class="stat-value">{initial_stats['average_confidence']}</div>
1384
+ <div class="stat-label">Avg. Confidence</div>
1385
+ </div>
1386
+ </div>
1387
+ """
1388
+
1389
+ # Hidden state to store results for feedback
1390
+ hidden_results = gr.State()
1391
+
1392
+ # Set up event handlers
1393
+ submit_button.click(
1394
+ fn=start_processing,
1395
+ inputs=[query_input],
1396
+ outputs=[progress_display, results_accordion, feedback_accordion, hidden_results],
1397
+ queue=False
1398
+ ).then(
1399
+ fn=process_query_and_display_results,
1400
+ inputs=[query_input],
1401
+ outputs=[progress_display, results_accordion, feedback_accordion, hidden_results]
1402
+ )
1403
+
1404
+ feedback_button.click(
1405
+ fn=combine_feedback,
1406
+ inputs=[feedback_input, feedback_text, hidden_results],
1407
+ outputs=[feedback_status, feedback_stats]
1408
+ )
1409
+
1410
+ # Footer
1411
+ gr.HTML(
1412
+ """
1413
+ <footer>
1414
+ <p>Paraphrase-based Approach for Scrutinizing Systems (PAS2) - Advanced Hallucination Detection</p>
1415
+ <p>Using Mistral Large for generation and OpenAI o3-mini as judge</p>
1416
+ </footer>
1417
+ """
1418
+ )
1419
+
1420
+ return interface
1421
+
1422
+ # Add a test function to demonstrate progress bar in isolation
1423
+ def test_progress():
1424
+ """Simple test function to demonstrate progress bar"""
1425
+ import gradio as gr
1426
+ import time
1427
+
1428
+ def slow_process(progress=gr.Progress()):
1429
+ progress(0, desc="Starting process...")
1430
+ time.sleep(0.5)
1431
+
1432
+ # Phase 1: Generating paraphrases
1433
+ progress(0.15, desc="Generating paraphrases...")
1434
+ time.sleep(1)
1435
+ progress(0.3, desc="Paraphrases generated")
1436
+ time.sleep(0.5)
1437
+
1438
+ # Phase 2: Getting responses
1439
+ progress(0.35, desc="Getting responses...")
1440
+ # Show incremental progress for responses
1441
+ for i in range(3):
1442
+ time.sleep(0.8)
1443
+ prog = 0.35 + (0.3 * ((i+1) / 3))
1444
+ progress(prog, desc=f"Getting responses ({i+1}/3)...")
1445
+
1446
+ progress(0.65, desc="All responses received")
1447
+ time.sleep(0.5)
1448
+
1449
+ # Phase 3: Analyzing
1450
+ progress(0.7, desc="Analyzing responses for hallucinations...")
1451
+ time.sleep(2)
1452
+
1453
+ # Complete
1454
+ progress(1.0, desc="Analysis complete!")
1455
+ return "Process completed successfully!"
1456
+
1457
+ with gr.Blocks() as demo:
1458
+ with gr.Row():
1459
+ btn = gr.Button("Start Process")
1460
+ output = gr.Textbox(label="Result")
1461
+
1462
+ btn.click(fn=slow_process, outputs=output)
1463
+
1464
+ demo.launch()
1465
+
1466
+ # Main application entry point
1467
+ if __name__ == "__main__":
1468
+ logger.info("Starting PAS2 Hallucination Detector")
1469
+ interface = create_interface()
1470
+ logger.info("Launching Gradio interface...")
1471
+ interface.launch(
1472
+ server_name="0.0.0.0", # Bind to all interfaces
1473
+ server_port=7860, # Default Hugging Face Spaces port
1474
+ show_api=False,
1475
+ quiet=True, # Changed to True for Hugging Face deployment
1476
+ share=False,
1477
+ max_threads=10,
1478
+ debug=False # Changed to False for production deployment
1479
+ )
1480
+
1481
+ # Uncomment this line to run the test function instead of the main interface
1482
+ # if __name__ == "__main__":
1483
+ # test_progress()
pas2.py DELETED
@@ -1,1483 +0,0 @@
1
- import os
2
- import gradio as gr
3
- import pandas as pd
4
- from datetime import datetime
5
- from pydantic import BaseModel, Field
6
- from typing import List, Dict, Any, Optional
7
- import numpy as np
8
- from mistralai import Mistral
9
- from openai import OpenAI
10
- import re
11
- import json
12
- import logging
13
- import time
14
- import concurrent.futures
15
- from concurrent.futures import ThreadPoolExecutor
16
- import threading
17
- import sqlite3
18
-
19
- # Configure logging
20
- logging.basicConfig(
21
- level=logging.INFO,
22
- format='%(asctime)s [%(levelname)s] %(message)s',
23
- handlers=[
24
- logging.StreamHandler()
25
- ]
26
- )
27
-
28
- logger = logging.getLogger(__name__)
29
-
30
- class HallucinationJudgment(BaseModel):
31
- hallucination_detected: bool = Field(description="Whether a hallucination is detected across the responses")
32
- confidence_score: float = Field(description="Confidence score between 0-1 for the hallucination judgment")
33
- conflicting_facts: List[Dict[str, Any]] = Field(description="List of conflicting facts found in the responses")
34
- reasoning: str = Field(description="Detailed reasoning for the judgment")
35
- summary: str = Field(description="A summary of the analysis")
36
-
37
- class PAS2:
38
- """Paraphrase-based Approach for Scrutinizing Systems - Using model-as-judge"""
39
-
40
- def __init__(self, mistral_api_key=None, openai_api_key=None, progress_callback=None):
41
- """Initialize the PAS2 with API keys"""
42
- # For Hugging Face Spaces, we prioritize getting API keys from HF_* environment variables
43
- # which are set from the Secrets tab in the Space settings
44
- self.mistral_api_key = mistral_api_key or os.environ.get("HF_MISTRAL_API_KEY") or os.environ.get("MISTRAL_API_KEY")
45
- self.openai_api_key = openai_api_key or os.environ.get("HF_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
46
- self.progress_callback = progress_callback
47
-
48
- if not self.mistral_api_key:
49
- raise ValueError("Mistral API key is required. Set it via HF_MISTRAL_API_KEY in Hugging Face Spaces secrets or pass it as a parameter.")
50
-
51
- if not self.openai_api_key:
52
- raise ValueError("OpenAI API key is required. Set it via HF_OPENAI_API_KEY in Hugging Face Spaces secrets or pass it as a parameter.")
53
-
54
- self.mistral_client = Mistral(api_key=self.mistral_api_key)
55
- self.openai_client = OpenAI(api_key=self.openai_api_key)
56
-
57
- self.mistral_model = "mistral-large-latest"
58
- self.openai_model = "o3-mini"
59
-
60
- logger.info("PAS2 initialized with Mistral model: %s and OpenAI model: %s",
61
- self.mistral_model, self.openai_model)
62
-
63
- def generate_paraphrases(self, query: str, n_paraphrases: int = 3) -> List[str]:
64
- """Generate paraphrases of the input query using Mistral API"""
65
- logger.info("Generating %d paraphrases for query: %s", n_paraphrases, query)
66
- start_time = time.time()
67
-
68
- messages = [
69
- {
70
- "role": "system",
71
- "content": f"You are an expert at creating semantically equivalent paraphrases. Generate {n_paraphrases} different paraphrases of the given query that preserve the original meaning but vary in wording and structure. Return a JSON array of strings, each containing one paraphrase."
72
- },
73
- {
74
- "role": "user",
75
- "content": query
76
- }
77
- ]
78
-
79
- try:
80
- logger.info("Sending paraphrase generation request to Mistral API...")
81
- response = self.mistral_client.chat.complete(
82
- model=self.mistral_model,
83
- messages=messages,
84
- response_format={"type": "json_object"}
85
- )
86
-
87
- content = response.choices[0].message.content
88
- logger.debug("Received raw paraphrase response: %s", content)
89
-
90
- paraphrases_data = json.loads(content)
91
-
92
- # Handle different possible JSON structures
93
- if isinstance(paraphrases_data, dict) and "paraphrases" in paraphrases_data:
94
- paraphrases = paraphrases_data["paraphrases"]
95
- elif isinstance(paraphrases_data, dict) and "results" in paraphrases_data:
96
- paraphrases = paraphrases_data["results"]
97
- elif isinstance(paraphrases_data, list):
98
- paraphrases = paraphrases_data
99
- else:
100
- # Try to extract a list from any field
101
- for key, value in paraphrases_data.items():
102
- if isinstance(value, list) and len(value) > 0:
103
- paraphrases = value
104
- break
105
- else:
106
- logger.warning("Could not extract paraphrases from response: %s", content)
107
- raise ValueError(f"Could not extract paraphrases from response: {content}")
108
-
109
- # Ensure we have the right number of paraphrases
110
- paraphrases = paraphrases[:n_paraphrases]
111
-
112
- # Add the original query as the first item
113
- all_queries = [query] + paraphrases
114
-
115
- elapsed_time = time.time() - start_time
116
- logger.info("Generated %d paraphrases in %.2f seconds", len(paraphrases), elapsed_time)
117
- for i, p in enumerate(paraphrases, 1):
118
- logger.info("Paraphrase %d: %s", i, p)
119
-
120
- return all_queries
121
-
122
- except Exception as e:
123
- logger.error("Error generating paraphrases: %s", str(e), exc_info=True)
124
- # Return original plus simple paraphrases as fallback
125
- fallback_paraphrases = [
126
- query,
127
- f"Could you tell me about {query.strip('?')}?",
128
- f"I'd like to know: {query}",
129
- f"Please provide information on {query.strip('?')}."
130
- ][:n_paraphrases+1]
131
-
132
- logger.info("Using fallback paraphrases due to error")
133
- for i, p in enumerate(fallback_paraphrases[1:], 1):
134
- logger.info("Fallback paraphrase %d: %s", i, p)
135
-
136
- return fallback_paraphrases
137
-
138
- def _get_single_response(self, query: str, index: int = None) -> str:
139
- """Get a single response from Mistral API for a query"""
140
- try:
141
- query_description = f"Query {index}: {query}" if index is not None else f"Query: {query}"
142
- logger.info("Getting response for %s", query_description)
143
- start_time = time.time()
144
-
145
- messages = [
146
- {
147
- "role": "system",
148
- "content": "You are a helpful AI assistant. Provide accurate, factual information in response to questions."
149
- },
150
- {
151
- "role": "user",
152
- "content": query
153
- }
154
- ]
155
-
156
- response = self.mistral_client.chat.complete(
157
- model=self.mistral_model,
158
- messages=messages
159
- )
160
-
161
- result = response.choices[0].message.content
162
- elapsed_time = time.time() - start_time
163
-
164
- logger.info("Received response for %s (%.2f seconds)", query_description, elapsed_time)
165
- logger.debug("Response content for %s: %s", query_description, result[:100] + "..." if len(result) > 100 else result)
166
-
167
- return result
168
-
169
- except Exception as e:
170
- error_msg = f"Error getting response for query '{query}': {e}"
171
- logger.error(error_msg, exc_info=True)
172
- return f"Error: Failed to get response for this query."
173
-
174
- def get_responses(self, queries: List[str]) -> List[str]:
175
- """Get responses from Mistral API for each query in parallel"""
176
- logger.info("Getting responses for %d queries in parallel", len(queries))
177
- start_time = time.time()
178
-
179
- # Use ThreadPoolExecutor for parallel API calls
180
- with ThreadPoolExecutor(max_workers=min(len(queries), 5)) as executor:
181
- # Submit tasks and map them to their original indices
182
- future_to_index = {
183
- executor.submit(self._get_single_response, query, i): i
184
- for i, query in enumerate(queries)
185
- }
186
-
187
- # Prepare a list with the correct length
188
- responses = [""] * len(queries)
189
-
190
- # Counter for completed responses
191
- completed_count = 0
192
-
193
- # Collect results as they complete
194
- for future in concurrent.futures.as_completed(future_to_index):
195
- index = future_to_index[future]
196
- try:
197
- responses[index] = future.result()
198
-
199
- # Update completion count and report progress
200
- completed_count += 1
201
- if self.progress_callback:
202
- self.progress_callback("responses_progress",
203
- completed_responses=completed_count,
204
- total_responses=len(queries))
205
-
206
- except Exception as e:
207
- logger.error("Error processing response for index %d: %s", index, str(e))
208
- responses[index] = f"Error: Failed to get response for query {index}."
209
-
210
- # Still update completion count even for errors
211
- completed_count += 1
212
- if self.progress_callback:
213
- self.progress_callback("responses_progress",
214
- completed_responses=completed_count,
215
- total_responses=len(queries))
216
-
217
- elapsed_time = time.time() - start_time
218
- logger.info("Received all %d responses in %.2f seconds total", len(responses), elapsed_time)
219
-
220
- return responses
221
-
222
- def detect_hallucination(self, query: str, n_paraphrases: int = 3) -> Dict:
223
- """
224
- Detect hallucinations by comparing responses to paraphrased queries using a judge model
225
-
226
- Returns:
227
- Dict containing hallucination judgment and all responses
228
- """
229
- logger.info("Starting hallucination detection for query: %s", query)
230
- start_time = time.time()
231
-
232
- # Report progress
233
- if self.progress_callback:
234
- self.progress_callback("starting", query=query)
235
-
236
- # Generate paraphrases
237
- logger.info("Step 1: Generating paraphrases")
238
- if self.progress_callback:
239
- self.progress_callback("generating_paraphrases", query=query)
240
-
241
- all_queries = self.generate_paraphrases(query, n_paraphrases)
242
-
243
- if self.progress_callback:
244
- self.progress_callback("paraphrases_complete", query=query, count=len(all_queries))
245
-
246
- # Get responses to all queries
247
- logger.info("Step 2: Getting responses to all %d queries", len(all_queries))
248
- if self.progress_callback:
249
- self.progress_callback("getting_responses", query=query, total=len(all_queries))
250
-
251
- all_responses = []
252
- for i, q in enumerate(all_queries):
253
- logger.info("Getting response %d/%d for query: %s", i+1, len(all_queries), q)
254
- if self.progress_callback:
255
- self.progress_callback("responses_progress", query=query, completed=i, total=len(all_queries))
256
-
257
- response = self._get_single_response(q, index=i)
258
- all_responses.append(response)
259
-
260
- if self.progress_callback:
261
- self.progress_callback("responses_complete", query=query)
262
-
263
- # Judge the responses for hallucinations
264
- logger.info("Step 3: Judging for hallucinations")
265
- if self.progress_callback:
266
- self.progress_callback("judging", query=query)
267
-
268
- # The first query is the original, rest are paraphrases
269
- original_query = all_queries[0]
270
- original_response = all_responses[0]
271
- paraphrased_queries = all_queries[1:] if len(all_queries) > 1 else []
272
- paraphrased_responses = all_responses[1:] if len(all_responses) > 1 else []
273
-
274
- # Judge the responses
275
- judgment = self.judge_hallucination(
276
- original_query=original_query,
277
- original_response=original_response,
278
- paraphrased_queries=paraphrased_queries,
279
- paraphrased_responses=paraphrased_responses
280
- )
281
-
282
- # Assemble the results
283
- results = {
284
- "original_query": original_query,
285
- "original_response": original_response,
286
- "paraphrased_queries": paraphrased_queries,
287
- "paraphrased_responses": paraphrased_responses,
288
- "hallucination_detected": judgment.hallucination_detected,
289
- "confidence_score": judgment.confidence_score,
290
- "conflicting_facts": judgment.conflicting_facts,
291
- "reasoning": judgment.reasoning,
292
- "summary": judgment.summary
293
- }
294
-
295
- # Report completion
296
- if self.progress_callback:
297
- self.progress_callback("complete", query=query)
298
-
299
- logger.info("Hallucination detection completed in %.2f seconds", time.time() - start_time)
300
- return results
301
-
302
- def judge_hallucination(self,
303
- original_query: str,
304
- original_response: str,
305
- paraphrased_queries: List[str],
306
- paraphrased_responses: List[str]) -> HallucinationJudgment:
307
- """
308
- Use OpenAI's o3-mini as a judge to detect hallucinations in the responses
309
- """
310
- logger.info("Judging hallucinations with OpenAI's %s model", self.openai_model)
311
- start_time = time.time()
312
-
313
- # Prepare the context for the judge
314
- context = f"""
315
- Original Question: {original_query}
316
-
317
- Original Response:
318
- {original_response}
319
-
320
- Paraphrased Questions and their Responses:
321
- """
322
-
323
- for i, (query, response) in enumerate(zip(paraphrased_queries, paraphrased_responses), 1):
324
- context += f"\nParaphrased Question {i}: {query}\n\nResponse {i}:\n{response}\n"
325
-
326
- system_prompt = """
327
- You are a judge evaluating whether an AI is hallucinating across different responses to semantically equivalent questions.
328
- Analyze all responses carefully to identify any factual inconsistencies or contradictions.
329
- Focus on factual discrepancies, not stylistic differences.
330
- A hallucination is when the AI states different facts in response to questions that are asking for the same information.
331
-
332
- Your response should be a JSON with the following fields:
333
- - hallucination_detected: boolean indicating whether hallucinations were found
334
- - confidence_score: number between 0 and 1 representing your confidence in the judgment
335
- - conflicting_facts: an array of objects describing any conflicting information found
336
- - reasoning: detailed explanation for your judgment
337
- - summary: a concise summary of your analysis
338
- """
339
-
340
- try:
341
- logger.info("Sending judgment request to OpenAI API...")
342
- response = self.openai_client.chat.completions.create(
343
- model=self.openai_model,
344
- messages=[
345
- {"role": "system", "content": system_prompt},
346
- {"role": "user", "content": f"Evaluate these responses for hallucinations:\n\n{context}"}
347
- ],
348
- response_format={"type": "json_object"}
349
- )
350
-
351
- result_json = json.loads(response.choices[0].message.content)
352
- logger.debug("Received judgment response: %s", result_json)
353
-
354
- # Create the HallucinationJudgment object from the JSON response
355
- judgment = HallucinationJudgment(
356
- hallucination_detected=result_json.get("hallucination_detected", False),
357
- confidence_score=result_json.get("confidence_score", 0.0),
358
- conflicting_facts=result_json.get("conflicting_facts", []),
359
- reasoning=result_json.get("reasoning", "No reasoning provided."),
360
- summary=result_json.get("summary", "No summary provided.")
361
- )
362
-
363
- elapsed_time = time.time() - start_time
364
- logger.info("Judgment completed in %.2f seconds", elapsed_time)
365
-
366
- return judgment
367
-
368
- except Exception as e:
369
- logger.error("Error in hallucination judgment: %s", str(e), exc_info=True)
370
- # Return a fallback judgment
371
- return HallucinationJudgment(
372
- hallucination_detected=False,
373
- confidence_score=0.0,
374
- conflicting_facts=[],
375
- reasoning="Failed to obtain judgment from the model.",
376
- summary="Analysis failed due to API error."
377
- )
378
-
379
-
380
- class HallucinationDetectorApp:
381
- def __init__(self):
382
- self.pas2 = None
383
- # Use the default HF Spaces persistent storage location
384
- self.data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
385
- self.db_path = os.path.join(self.data_dir, "feedback.db")
386
- logger.info("Initializing HallucinationDetectorApp")
387
- self._initialize_database()
388
- self.progress_callback = None
389
-
390
- def _initialize_database(self):
391
- """Initialize SQLite database for feedback storage in persistent directory"""
392
- try:
393
- # Create data directory if it doesn't exist
394
- os.makedirs(self.data_dir, exist_ok=True)
395
- logger.info(f"Ensuring data directory exists at {self.data_dir}")
396
-
397
- conn = sqlite3.connect(self.db_path)
398
- cursor = conn.cursor()
399
-
400
- # Create table if it doesn't exist
401
- cursor.execute('''
402
- CREATE TABLE IF NOT EXISTS feedback (
403
- id INTEGER PRIMARY KEY AUTOINCREMENT,
404
- timestamp TEXT,
405
- original_query TEXT,
406
- original_response TEXT,
407
- paraphrased_queries TEXT,
408
- paraphrased_responses TEXT,
409
- hallucination_detected INTEGER,
410
- confidence_score REAL,
411
- conflicting_facts TEXT,
412
- reasoning TEXT,
413
- summary TEXT,
414
- user_feedback TEXT
415
- )
416
- ''')
417
-
418
- conn.commit()
419
- conn.close()
420
- logger.info(f"Database initialized successfully at {self.db_path}")
421
- except Exception as e:
422
- logger.error(f"Error initializing database: {str(e)}", exc_info=True)
423
- # Fallback to temporary directory if /data is not accessible
424
- temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp_data")
425
- os.makedirs(temp_dir, exist_ok=True)
426
- self.db_path = os.path.join(temp_dir, "feedback.db")
427
- logger.warning(f"Using fallback database location: {self.db_path}")
428
-
429
- # Try creating database in fallback location
430
- try:
431
- conn = sqlite3.connect(self.db_path)
432
- cursor = conn.cursor()
433
- cursor.execute('''
434
- CREATE TABLE IF NOT EXISTS feedback (
435
- id INTEGER PRIMARY KEY AUTOINCREMENT,
436
- timestamp TEXT,
437
- original_query TEXT,
438
- original_response TEXT,
439
- paraphrased_queries TEXT,
440
- paraphrased_responses TEXT,
441
- hallucination_detected INTEGER,
442
- confidence_score REAL,
443
- conflicting_facts TEXT,
444
- reasoning TEXT,
445
- summary TEXT,
446
- user_feedback TEXT
447
- )
448
- ''')
449
- conn.commit()
450
- conn.close()
451
- logger.info(f"Database initialized in fallback location")
452
- except Exception as fallback_error:
453
- logger.error(f"Critical error: Could not initialize database in fallback location: {str(fallback_error)}", exc_info=True)
454
- raise
455
-
456
- def set_progress_callback(self, callback):
457
- """Set the progress callback function"""
458
- self.progress_callback = callback
459
-
460
- def initialize_api(self, mistral_api_key, openai_api_key):
461
- """Initialize the PAS2 with API keys"""
462
- try:
463
- logger.info("Initializing PAS2 with API keys")
464
- self.pas2 = PAS2(
465
- mistral_api_key=mistral_api_key,
466
- openai_api_key=openai_api_key,
467
- progress_callback=self.progress_callback
468
- )
469
- logger.info("API initialization successful")
470
- return "API keys set successfully! You can now use the application."
471
- except Exception as e:
472
- logger.error("Error initializing API: %s", str(e), exc_info=True)
473
- return f"Error initializing API: {str(e)}"
474
-
475
- def process_query(self, query: str):
476
- """Process the query using PAS2"""
477
- if not self.pas2:
478
- logger.error("PAS2 not initialized")
479
- return {
480
- "error": "Please set API keys first before processing queries."
481
- }
482
-
483
- if not query.strip():
484
- logger.warning("Empty query provided")
485
- return {
486
- "error": "Please enter a query."
487
- }
488
-
489
- try:
490
- # Set the progress callback if needed
491
- if self.progress_callback and self.pas2.progress_callback != self.progress_callback:
492
- self.pas2.progress_callback = self.progress_callback
493
-
494
- # Process the query
495
- logger.info("Processing query with PAS2: %s", query)
496
- results = self.pas2.detect_hallucination(query)
497
- logger.info("Query processing completed successfully")
498
- return results
499
- except Exception as e:
500
- logger.error("Error processing query: %s", str(e), exc_info=True)
501
- return {
502
- "error": f"Error processing query: {str(e)}"
503
- }
504
-
505
- def save_feedback(self, results, feedback):
506
- """Save results and user feedback to SQLite database"""
507
- try:
508
- logger.info("Saving user feedback: %s", feedback)
509
-
510
- conn = sqlite3.connect(self.db_path)
511
- cursor = conn.cursor()
512
-
513
- # Prepare data
514
- data = (
515
- datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
516
- results.get('original_query', ''),
517
- results.get('original_response', ''),
518
- str(results.get('paraphrased_queries', [])),
519
- str(results.get('paraphrased_responses', [])),
520
- 1 if results.get('hallucination_detected', False) else 0,
521
- results.get('confidence_score', 0.0),
522
- str(results.get('conflicting_facts', [])),
523
- results.get('reasoning', ''),
524
- results.get('summary', ''),
525
- feedback
526
- )
527
-
528
- # Insert data
529
- cursor.execute('''
530
- INSERT INTO feedback (
531
- timestamp, original_query, original_response,
532
- paraphrased_queries, paraphrased_responses,
533
- hallucination_detected, confidence_score,
534
- conflicting_facts, reasoning, summary, user_feedback
535
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
536
- ''', data)
537
-
538
- conn.commit()
539
- conn.close()
540
-
541
- logger.info("Feedback saved successfully to database")
542
- return "Feedback saved successfully!"
543
- except Exception as e:
544
- logger.error("Error saving feedback: %s", str(e), exc_info=True)
545
- return f"Error saving feedback: {str(e)}"
546
-
547
- def get_feedback_stats(self):
548
- """Get statistics about collected feedback"""
549
- try:
550
- conn = sqlite3.connect(self.db_path)
551
- cursor = conn.cursor()
552
-
553
- # Get total feedback count
554
- cursor.execute("SELECT COUNT(*) FROM feedback")
555
- total_count = cursor.fetchone()[0]
556
-
557
- # Get hallucination detection stats
558
- cursor.execute("""
559
- SELECT hallucination_detected, COUNT(*)
560
- FROM feedback
561
- GROUP BY hallucination_detected
562
- """)
563
- detection_stats = dict(cursor.fetchall())
564
-
565
- # Get average confidence score
566
- cursor.execute("SELECT AVG(confidence_score) FROM feedback")
567
- avg_confidence = cursor.fetchone()[0] or 0
568
-
569
- conn.close()
570
-
571
- return {
572
- "total_feedback": total_count,
573
- "hallucinations_detected": detection_stats.get(1, 0),
574
- "no_hallucinations": detection_stats.get(0, 0),
575
- "average_confidence": round(avg_confidence, 2)
576
- }
577
- except Exception as e:
578
- logger.error("Error getting feedback stats: %s", str(e), exc_info=True)
579
- return None
580
-
581
-
582
- # Progress tracking for UI updates
583
- class ProgressTracker:
584
- """Tracks progress of hallucination detection for UI updates"""
585
-
586
- STAGES = {
587
- "idle": {"status": "Ready", "progress": 0, "color": "#757575"},
588
- "starting": {"status": "Starting process...", "progress": 5, "color": "#2196F3"},
589
- "generating_paraphrases": {"status": "Generating paraphrases...", "progress": 15, "color": "#2196F3"},
590
- "paraphrases_complete": {"status": "Paraphrases generated", "progress": 30, "color": "#2196F3"},
591
- "getting_responses": {"status": "Getting responses (0/0)...", "progress": 35, "color": "#2196F3"},
592
- "responses_progress": {"status": "Getting responses ({completed}/{total})...", "progress": 40, "color": "#2196F3"},
593
- "responses_complete": {"status": "All responses received", "progress": 65, "color": "#2196F3"},
594
- "judging": {"status": "Analyzing responses for hallucinations...", "progress": 70, "color": "#2196F3"},
595
- "complete": {"status": "Analysis complete!", "progress": 100, "color": "#4CAF50"},
596
- "error": {"status": "Error: {error_message}", "progress": 100, "color": "#F44336"}
597
- }
598
-
599
- def __init__(self):
600
- self.stage = "idle"
601
- self.stage_data = self.STAGES[self.stage].copy()
602
- self.query = ""
603
- self.completed_responses = 0
604
- self.total_responses = 0
605
- self.error_message = ""
606
- self._lock = threading.Lock()
607
- self._status_callback = None
608
- self._stop_event = threading.Event()
609
- self._update_thread = None
610
-
611
- def register_callback(self, callback_fn):
612
- """Register callback function to update UI"""
613
- self._status_callback = callback_fn
614
-
615
- def update_stage(self, stage, **kwargs):
616
- """Update the current stage and trigger callback"""
617
- with self._lock:
618
- if stage in self.STAGES:
619
- self.stage = stage
620
- self.stage_data = self.STAGES[stage].copy()
621
-
622
- # Update with any additional parameters
623
- for key, value in kwargs.items():
624
- if key == 'query':
625
- self.query = value
626
- elif key == 'completed_responses':
627
- self.completed_responses = value
628
- elif key == 'total_responses':
629
- self.total_responses = value
630
- elif key == 'error_message':
631
- self.error_message = value
632
-
633
- # Format status message
634
- if stage == 'responses_progress':
635
- self.stage_data['status'] = self.stage_data['status'].format(
636
- completed=self.completed_responses,
637
- total=self.total_responses
638
- )
639
- elif stage == 'error':
640
- self.stage_data['status'] = self.stage_data['status'].format(
641
- error_message=self.error_message
642
- )
643
-
644
- if self._status_callback:
645
- self._status_callback(self.get_html_status())
646
-
647
- def get_html_status(self):
648
- """Get HTML representation of current status"""
649
- progress_width = f"{self.stage_data['progress']}%"
650
- status_text = self.stage_data['status']
651
- color = self.stage_data['color']
652
-
653
- query_info = f'<div class="query-display">{self.query}</div>' if self.query else ''
654
-
655
- # Only show status text if not in idle state
656
- status_display = f'<div class="progress-status" style="color: {color};">{status_text}</div>' if self.stage != "idle" else ''
657
-
658
- html = f"""
659
- <div class="progress-container">
660
- {query_info}
661
- {status_display}
662
- <div class="progress-bar-container">
663
- <div class="progress-bar" style="width: {progress_width}; background-color: {color};"></div>
664
- </div>
665
- </div>
666
- """
667
- return html
668
-
669
- def start_pulsing(self):
670
- """Start a pulsing animation for the progress bar during long operations"""
671
- if self._update_thread and self._update_thread.is_alive():
672
- return
673
-
674
- self._stop_event.clear()
675
- self._update_thread = threading.Thread(target=self._pulse_progress)
676
- self._update_thread.daemon = True
677
- self._update_thread.start()
678
-
679
- def stop_pulsing(self):
680
- """Stop the pulsing animation"""
681
- self._stop_event.set()
682
- if self._update_thread:
683
- self._update_thread.join(0.5)
684
-
685
- def _pulse_progress(self):
686
- """Animate the progress bar to show activity"""
687
- pulse_stages = ["⋯", "⋯⋯", "⋯⋯⋯", "⋯⋯", "⋯"]
688
- i = 0
689
- while not self._stop_event.is_set():
690
- with self._lock:
691
- if self.stage not in ["idle", "complete", "error"]:
692
- status_base = self.stage_data['status'].split("...")[0] if "..." in self.stage_data['status'] else self.stage_data['status']
693
- self.stage_data['status'] = f"{status_base}... {pulse_stages[i]}"
694
-
695
- if self._status_callback:
696
- self._status_callback(self.get_html_status())
697
-
698
- i = (i + 1) % len(pulse_stages)
699
- time.sleep(0.3)
700
-
701
-
702
- def create_interface():
703
- """Create Gradio interface"""
704
- detector = HallucinationDetectorApp()
705
-
706
- # Initialize Progress Tracker
707
- progress_tracker = ProgressTracker()
708
-
709
- # Initialize APIs from environment variables automatically
710
- try:
711
- detector.initialize_api(
712
- mistral_api_key=os.environ.get("HF_MISTRAL_API_KEY"),
713
- openai_api_key=os.environ.get("HF_OPENAI_API_KEY")
714
- )
715
- except Exception as e:
716
- print(f"Warning: Failed to initialize APIs from environment variables: {e}")
717
- print("Please make sure HF_MISTRAL_API_KEY and HF_OPENAI_API_KEY are set in your environment")
718
-
719
- # CSS for styling
720
- css = """
721
- .container {
722
- max-width: 1000px;
723
- margin: 0 auto;
724
- }
725
- .title {
726
- text-align: center;
727
- margin-bottom: 0.5em;
728
- color: #1a237e;
729
- font-weight: 600;
730
- }
731
- .subtitle {
732
- text-align: center;
733
- margin-bottom: 1.5em;
734
- color: #455a64;
735
- font-size: 1.2em;
736
- }
737
- .section-title {
738
- margin-top: 1em;
739
- margin-bottom: 0.5em;
740
- font-weight: bold;
741
- color: #283593;
742
- }
743
- .info-box {
744
- padding: 1.2em;
745
- border-radius: 8px;
746
- background-color: #f5f5f5;
747
- margin-bottom: 1em;
748
- box-shadow: 0 2px 5px rgba(0,0,0,0.05);
749
- }
750
- .hallucination-positive {
751
- padding: 1.2em;
752
- border-radius: 8px;
753
- background-color: #ffebee;
754
- border-left: 5px solid #f44336;
755
- margin-bottom: 1em;
756
- box-shadow: 0 2px 5px rgba(0,0,0,0.05);
757
- }
758
- .hallucination-negative {
759
- padding: 1.2em;
760
- border-radius: 8px;
761
- background-color: #e8f5e9;
762
- border-left: 5px solid #4caf50;
763
- margin-bottom: 1em;
764
- box-shadow: 0 2px 5px rgba(0,0,0,0.05);
765
- }
766
- .response-box {
767
- padding: 1.2em;
768
- border-radius: 8px;
769
- background-color: #f5f5f5;
770
- margin-bottom: 0.8em;
771
- box-shadow: 0 2px 5px rgba(0,0,0,0.05);
772
- }
773
- .example-queries {
774
- display: flex;
775
- flex-wrap: wrap;
776
- gap: 8px;
777
- margin-bottom: 15px;
778
- }
779
- .example-query {
780
- background-color: #e3f2fd;
781
- padding: 8px 15px;
782
- border-radius: 18px;
783
- font-size: 0.9em;
784
- cursor: pointer;
785
- transition: all 0.2s;
786
- border: 1px solid #bbdefb;
787
- }
788
- .example-query:hover {
789
- background-color: #bbdefb;
790
- box-shadow: 0 2px 5px rgba(0,0,0,0.1);
791
- }
792
- .stats-section {
793
- display: flex;
794
- justify-content: space-between;
795
- background-color: #e8eaf6;
796
- padding: 15px;
797
- border-radius: 8px;
798
- margin-bottom: 20px;
799
- }
800
- .stat-item {
801
- text-align: center;
802
- padding: 10px;
803
- }
804
- .stat-value {
805
- font-size: 1.5em;
806
- font-weight: bold;
807
- color: #303f9f;
808
- }
809
- .stat-label {
810
- font-size: 0.9em;
811
- color: #5c6bc0;
812
- }
813
- .feedback-section {
814
- border-top: 1px solid #e0e0e0;
815
- padding-top: 15px;
816
- margin-top: 20px;
817
- }
818
- footer {
819
- text-align: center;
820
- padding: 20px;
821
- margin-top: 30px;
822
- color: #9e9e9e;
823
- font-size: 0.9em;
824
- }
825
- .processing-status {
826
- padding: 12px;
827
- background-color: #fff3e0;
828
- border-left: 4px solid #ff9800;
829
- margin-bottom: 15px;
830
- font-weight: 500;
831
- color: #e65100;
832
- }
833
- .debug-panel {
834
- background-color: #f5f5f5;
835
- border: 1px solid #e0e0e0;
836
- border-radius: 4px;
837
- padding: 10px;
838
- margin-top: 15px;
839
- font-family: monospace;
840
- font-size: 0.9em;
841
- white-space: pre-wrap;
842
- max-height: 200px;
843
- overflow-y: auto;
844
- }
845
- .progress-container {
846
- padding: 15px;
847
- background-color: #fff;
848
- border-radius: 8px;
849
- box-shadow: 0 2px 5px rgba(0,0,0,0.05);
850
- margin-bottom: 15px;
851
- }
852
- .progress-status {
853
- font-weight: 500;
854
- margin-bottom: 8px;
855
- padding: 4px 0;
856
- font-size: 0.95em;
857
- }
858
- .progress-bar-container {
859
- background-color: #e0e0e0;
860
- height: 10px;
861
- border-radius: 5px;
862
- overflow: hidden;
863
- margin-bottom: 10px;
864
- box-shadow: inset 0 1px 3px rgba(0,0,0,0.1);
865
- }
866
- .progress-bar {
867
- height: 100%;
868
- transition: width 0.5s ease;
869
- background-image: linear-gradient(to right, #2196F3, #3f51b5);
870
- }
871
- .query-display {
872
- font-style: italic;
873
- color: #666;
874
- margin-bottom: 10px;
875
- background-color: #f5f5f5;
876
- padding: 8px;
877
- border-radius: 4px;
878
- border-left: 3px solid #2196F3;
879
- }
880
- """
881
-
882
- # Example queries
883
- example_queries = [
884
- "Who was the first person to land on the moon?",
885
- "What is the capital of France?",
886
- "How many planets are in our solar system?",
887
- "Who wrote the novel 1984?",
888
- "What is the speed of light?",
889
- "What was the first computer?"
890
- ]
891
-
892
- # Function to update the progress display
893
- def update_progress_display(html):
894
- """Update the progress display with the provided HTML"""
895
- return gr.update(visible=True, value=html)
896
-
897
- # Register the callback with the tracker
898
- progress_tracker.register_callback(update_progress_display)
899
-
900
- # Register the tracker with the detector
901
- detector.set_progress_callback(progress_tracker.update_stage)
902
-
903
- # Helper function to set example query
904
- def set_example_query(example):
905
- return example
906
-
907
- # Function to show processing is starting
908
- def start_processing(query):
909
- logger.info("Processing query: %s", query)
910
- # Stop any existing pulsing to prepare for incremental progress updates
911
- progress_tracker.stop_pulsing()
912
-
913
- # Reset to a processing state without the "Ready" text
914
- # Use "starting" stage but with minimal UI display
915
- progress_tracker.stage = "starting"
916
- progress_tracker.query = query
917
-
918
- # Force UI update with clean display
919
- if progress_tracker._status_callback:
920
- progress_tracker._status_callback(progress_tracker.get_html_status())
921
-
922
- return [
923
- gr.update(visible=True), # Show the progress display
924
- gr.update(visible=False), # Hide the results accordion
925
- gr.update(visible=False), # Hide the feedback accordion
926
- None # Reset hidden results
927
- ]
928
-
929
- # Main processing function
930
- def process_query_and_display_results(query, progress=gr.Progress()):
931
- if not query.strip():
932
- logger.warning("Empty query submitted")
933
- progress_tracker.stop_pulsing()
934
- progress_tracker.update_stage("error", error_message="Please enter a query.")
935
- return [
936
- gr.update(visible=True), # Show the progress with error
937
- gr.update(visible=False),
938
- gr.update(visible=False),
939
- None
940
- ]
941
-
942
- # Check if API is initialized
943
- if not detector.pas2:
944
- try:
945
- # Try to initialize from environment variables
946
- logger.info("Initializing APIs from environment variables")
947
- progress(0.05, desc="Initializing API...")
948
- init_message = detector.initialize_api(
949
- mistral_api_key=os.environ.get("HF_MISTRAL_API_KEY"),
950
- openai_api_key=os.environ.get("HF_OPENAI_API_KEY")
951
- )
952
- if "successfully" not in init_message:
953
- logger.error("Failed to initialize APIs: %s", init_message)
954
- progress_tracker.stop_pulsing()
955
- progress_tracker.update_stage("error", error_message="API keys not found in environment variables.")
956
- return [
957
- gr.update(visible=True),
958
- gr.update(visible=False),
959
- gr.update(visible=False),
960
- None
961
- ]
962
- except Exception as e:
963
- logger.error("Error initializing API: %s", str(e), exc_info=True)
964
- progress_tracker.stop_pulsing()
965
- progress_tracker.update_stage("error", error_message=f"Error initializing API: {str(e)}")
966
- return [
967
- gr.update(visible=True),
968
- gr.update(visible=False),
969
- gr.update(visible=False),
970
- None
971
- ]
972
-
973
- try:
974
- # Process the query
975
- logger.info("Starting hallucination detection process")
976
- start_time = time.time()
977
-
978
- # Set up a custom progress callback that uses both the progress_tracker and the gr.Progress
979
- def combined_progress_callback(stage, **kwargs):
980
- # Skip the idle stage, which shows "Ready"
981
- if stage == "idle":
982
- return
983
-
984
- progress_tracker.update_stage(stage, **kwargs)
985
-
986
- # Map the stages to progress values for the gr.Progress bar
987
- stage_to_progress = {
988
- "starting": 0.05,
989
- "generating_paraphrases": 0.15,
990
- "paraphrases_complete": 0.3,
991
- "getting_responses": 0.35,
992
- "responses_progress": lambda kwargs: 0.35 + (0.3 * (kwargs.get("completed", 0) / max(kwargs.get("total", 1), 1))),
993
- "responses_complete": 0.65,
994
- "judging": 0.7,
995
- "complete": 1.0,
996
- "error": 1.0
997
- }
998
-
999
- # Update the gr.Progress bar
1000
- if stage in stage_to_progress:
1001
- prog_value = stage_to_progress[stage]
1002
- if callable(prog_value):
1003
- prog_value = prog_value(kwargs)
1004
-
1005
- desc = progress_tracker.STAGES[stage]["status"]
1006
- if "{" in desc and "}" in desc:
1007
- # Format the description with any kwargs
1008
- desc = desc.format(**kwargs)
1009
-
1010
- # Ensure UI updates by adding a small delay
1011
- # This forces the progress updates to be rendered
1012
- progress(prog_value, desc=desc)
1013
-
1014
- # For certain key stages, add a small sleep to ensure progress is visible
1015
- if stage in ["starting", "generating_paraphrases", "paraphrases_complete",
1016
- "getting_responses", "responses_complete", "judging", "complete"]:
1017
- time.sleep(0.2) # Small delay to ensure UI update is visible
1018
-
1019
- # Use these steps for processing
1020
- detector.set_progress_callback(combined_progress_callback)
1021
-
1022
- # Create a wrapper function for detect_hallucination that gives more control over progress updates
1023
- def run_detection_with_visible_progress():
1024
- # Step 1: Start
1025
- combined_progress_callback("starting", query=query)
1026
- time.sleep(0.3) # Ensure starting status is visible
1027
-
1028
- # Step 2: Generate paraphrases (15-30%)
1029
- combined_progress_callback("generating_paraphrases", query=query)
1030
- all_queries = detector.pas2.generate_paraphrases(query)
1031
- combined_progress_callback("paraphrases_complete", query=query, count=len(all_queries))
1032
-
1033
- # Step 3: Get responses (35-65%)
1034
- combined_progress_callback("getting_responses", query=query, total=len(all_queries))
1035
- all_responses = []
1036
- for i, q in enumerate(all_queries):
1037
- # Show incremental progress for each response
1038
- combined_progress_callback("responses_progress", query=query, completed=i, total=len(all_queries))
1039
- response = detector.pas2._get_single_response(q, index=i)
1040
- all_responses.append(response)
1041
- combined_progress_callback("responses_complete", query=query)
1042
-
1043
- # Step 4: Judge hallucinations (70-100%)
1044
- combined_progress_callback("judging", query=query)
1045
-
1046
- # The first query is the original, rest are paraphrases
1047
- original_query = all_queries[0]
1048
- original_response = all_responses[0]
1049
- paraphrased_queries = all_queries[1:] if len(all_queries) > 1 else []
1050
- paraphrased_responses = all_responses[1:] if len(all_responses) > 1 else []
1051
-
1052
- # Judge the responses
1053
- judgment = detector.pas2.judge_hallucination(
1054
- original_query=original_query,
1055
- original_response=original_response,
1056
- paraphrased_queries=paraphrased_queries,
1057
- paraphrased_responses=paraphrased_responses
1058
- )
1059
-
1060
- # Assemble the results
1061
- results = {
1062
- "original_query": original_query,
1063
- "original_response": original_response,
1064
- "paraphrased_queries": paraphrased_queries,
1065
- "paraphrased_responses": paraphrased_responses,
1066
- "hallucination_detected": judgment.hallucination_detected,
1067
- "confidence_score": judgment.confidence_score,
1068
- "conflicting_facts": judgment.conflicting_facts,
1069
- "reasoning": judgment.reasoning,
1070
- "summary": judgment.summary
1071
- }
1072
-
1073
- # Show completion
1074
- combined_progress_callback("complete", query=query)
1075
- time.sleep(0.3) # Ensure complete status is visible
1076
-
1077
- return results
1078
-
1079
- # Run the detection process with visible progress
1080
- results = run_detection_with_visible_progress()
1081
-
1082
- # Calculate elapsed time
1083
- elapsed_time = time.time() - start_time
1084
- logger.info("Hallucination detection completed in %.2f seconds", elapsed_time)
1085
-
1086
- # Check for errors
1087
- if "error" in results:
1088
- logger.error("Error in results: %s", results["error"])
1089
- progress_tracker.stop_pulsing()
1090
- progress_tracker.update_stage("error", error_message=results["error"])
1091
- return [
1092
- gr.update(visible=True),
1093
- gr.update(visible=False),
1094
- gr.update(visible=False),
1095
- None
1096
- ]
1097
-
1098
- # Prepare responses for display
1099
- original_query = results["original_query"]
1100
- original_response = results["original_response"]
1101
-
1102
- paraphrased_queries = results["paraphrased_queries"]
1103
- paraphrased_responses = results["paraphrased_responses"]
1104
-
1105
- hallucination_detected = results["hallucination_detected"]
1106
- confidence = results["confidence_score"]
1107
- reasoning = results["reasoning"]
1108
- summary = results["summary"]
1109
-
1110
- # Format conflicting facts
1111
- conflicting_facts = results["conflicting_facts"]
1112
- conflicting_facts_text = ""
1113
- if conflicting_facts:
1114
- for i, fact in enumerate(conflicting_facts, 1):
1115
- conflicting_facts_text += f"{i}. "
1116
- if isinstance(fact, dict):
1117
- for key, value in fact.items():
1118
- conflicting_facts_text += f"{key}: {value}, "
1119
- conflicting_facts_text = conflicting_facts_text.rstrip(", ")
1120
- else:
1121
- conflicting_facts_text += str(fact)
1122
- conflicting_facts_text += "\n"
1123
-
1124
- # Format responses to escape any backslashes
1125
- original_response_safe = original_response.replace('\\', '\\\\').replace('\n', '<br>')
1126
- paraphrased_responses_safe = [r.replace('\\', '\\\\').replace('\n', '<br>') for r in paraphrased_responses]
1127
- reasoning_safe = reasoning.replace('\\', '\\\\').replace('\n', '<br>')
1128
- conflicting_facts_text_safe = conflicting_facts_text.replace('\\', '\\\\').replace('\n', '<br>') if conflicting_facts_text else "None identified"
1129
-
1130
- html_output = f"""
1131
- <div class="container">
1132
- <h2 class="title">Hallucination Detection Results</h2>
1133
-
1134
- <div class="stats-section">
1135
- <div class="stat-item">
1136
- <div class="stat-value">{'Yes' if hallucination_detected else 'No'}</div>
1137
- <div class="stat-label">Hallucination Detected</div>
1138
- </div>
1139
- <div class="stat-item">
1140
- <div class="stat-value">{confidence:.2f}</div>
1141
- <div class="stat-label">Confidence Score</div>
1142
- </div>
1143
- <div class="stat-item">
1144
- <div class="stat-value">{len(paraphrased_queries)}</div>
1145
- <div class="stat-label">Paraphrases Analyzed</div>
1146
- </div>
1147
- <div class="stat-item">
1148
- <div class="stat-value">{elapsed_time:.1f}s</div>
1149
- <div class="stat-label">Processing Time</div>
1150
- </div>
1151
- </div>
1152
-
1153
- <div class="{'hallucination-positive' if hallucination_detected else 'hallucination-negative'}">
1154
- <h3>Analysis Summary</h3>
1155
- <p>{summary}</p>
1156
- </div>
1157
-
1158
- <div class="section-title">Original Query</div>
1159
- <div class="response-box">
1160
- {original_query}
1161
- </div>
1162
-
1163
- <div class="section-title">Original Response</div>
1164
- <div class="response-box">
1165
- {original_response_safe}
1166
- </div>
1167
-
1168
- <div class="section-title">Paraphrased Queries and Responses</div>
1169
- """
1170
-
1171
- for i, (q, r) in enumerate(zip(paraphrased_queries, paraphrased_responses_safe), 1):
1172
- html_output += f"""
1173
- <div class="section-title">Paraphrased Query {i}</div>
1174
- <div class="response-box">
1175
- {q}
1176
- </div>
1177
-
1178
- <div class="section-title">Response {i}</div>
1179
- <div class="response-box">
1180
- {r}
1181
- </div>
1182
- """
1183
-
1184
- html_output += f"""
1185
- <div class="section-title">Detailed Analysis</div>
1186
- <div class="info-box">
1187
- <p><strong>Reasoning:</strong></p>
1188
- <p>{reasoning_safe}</p>
1189
-
1190
- <p><strong>Conflicting Facts:</strong></p>
1191
- <p>{conflicting_facts_text_safe}</p>
1192
- </div>
1193
- </div>
1194
- """
1195
-
1196
- logger.info("Updating UI with results")
1197
- progress_tracker.stop_pulsing()
1198
-
1199
- return [
1200
- gr.update(visible=False), # Hide progress display when showing results
1201
- gr.update(visible=True, value=html_output),
1202
- gr.update(visible=True),
1203
- results
1204
- ]
1205
-
1206
- except Exception as e:
1207
- logger.error("Error processing query: %s", str(e), exc_info=True)
1208
- progress_tracker.stop_pulsing()
1209
- progress_tracker.update_stage("error", error_message=f"Error processing query: {str(e)}")
1210
- return [
1211
- gr.update(visible=True),
1212
- gr.update(visible=False),
1213
- gr.update(visible=False),
1214
- None
1215
- ]
1216
-
1217
- # Helper function to submit feedback and update stats
1218
- def combine_feedback(fb_input, fb_text, results):
1219
- combined_feedback = f"{fb_input}: {fb_text}" if fb_text else fb_input
1220
- if not results:
1221
- return "No results to attach feedback to.", ""
1222
-
1223
- response = detector.save_feedback(results, combined_feedback)
1224
-
1225
- # Get updated stats
1226
- stats = detector.get_feedback_stats()
1227
- if stats:
1228
- stats_html = f"""
1229
- <div class="stats-section" style="margin-top: 15px;">
1230
- <div class="stat-item">
1231
- <div class="stat-value">{stats['total_feedback']}</div>
1232
- <div class="stat-label">Total Feedback</div>
1233
- </div>
1234
- <div class="stat-item">
1235
- <div class="stat-value">{stats['hallucinations_detected']}</div>
1236
- <div class="stat-label">Hallucinations Found</div>
1237
- </div>
1238
- <div class="stat-item">
1239
- <div class="stat-value">{stats['no_hallucinations']}</div>
1240
- <div class="stat-label">No Hallucinations</div>
1241
- </div>
1242
- <div class="stat-item">
1243
- <div class="stat-value">{stats['average_confidence']}</div>
1244
- <div class="stat-label">Avg. Confidence</div>
1245
- </div>
1246
- </div>
1247
- """
1248
- else:
1249
- stats_html = ""
1250
-
1251
- return response, stats_html
1252
-
1253
- # Create the interface
1254
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as interface:
1255
- gr.HTML(
1256
- """
1257
- <div style="text-align: center; margin-bottom: 1.5rem">
1258
- <h1 style="font-size: 2.2em; font-weight: 600; color: #1a237e; margin-bottom: 0.2em;">PAS2 - Hallucination Detector</h1>
1259
- <h3 style="font-size: 1.3em; color: #455a64; margin-bottom: 0.8em;">Advanced AI Response Verification Using Model-as-Judge</h3>
1260
- <p style="font-size: 1.1em; color: #546e7a; max-width: 800px; margin: 0 auto;">
1261
- This tool detects hallucinations in AI responses by comparing answers to semantically equivalent questions and using a specialized judge model.
1262
- </p>
1263
- </div>
1264
- """
1265
- )
1266
-
1267
- with gr.Accordion("About this Tool", open=False):
1268
- gr.Markdown(
1269
- """
1270
- ### How It Works
1271
-
1272
- This tool implements the Paraphrase-based Approach for Scrutinizing Systems (PAS2) with a model-as-judge enhancement:
1273
-
1274
- 1. **Paraphrase Generation**: Your question is paraphrased multiple ways while preserving its core meaning
1275
- 2. **Multiple Responses**: All questions (original + paraphrases) are sent to Mistral Large model
1276
- 3. **Expert Judgment**: OpenAI's o3-mini analyzes all responses to detect factual inconsistencies
1277
-
1278
- ### Why This Approach?
1279
-
1280
- When an AI hallucinates, it often provides different answers to the same question when phrased differently.
1281
- By using a separate judge model, we can identify these inconsistencies more effectively than with
1282
- metric-based approaches.
1283
-
1284
- ### Understanding the Results
1285
-
1286
- - **Confidence Score**: Indicates the judge's confidence in the hallucination detection
1287
- - **Conflicting Facts**: Specific inconsistencies found across responses
1288
- - **Reasoning**: The judge's detailed analysis explaining its decision
1289
-
1290
- ### Privacy Notice
1291
-
1292
- Your queries and the system's responses are saved to help improve hallucination detection.
1293
- No personally identifiable information is collected.
1294
- """
1295
- )
1296
-
1297
- with gr.Row():
1298
- with gr.Column():
1299
- # First define the query input
1300
- gr.Markdown("### Enter Your Question")
1301
- with gr.Row():
1302
- query_input = gr.Textbox(
1303
- label="",
1304
- placeholder="Ask a factual question (e.g., Who was the first person to land on the moon?)",
1305
- lines=3
1306
- )
1307
-
1308
- # Now define the example queries
1309
- gr.Markdown("### Or Try an Example")
1310
- example_row = gr.Row()
1311
- with example_row:
1312
- for example in example_queries:
1313
- example_btn = gr.Button(
1314
- example,
1315
- elem_classes=["example-query"],
1316
- scale=0
1317
- )
1318
- example_btn.click(
1319
- fn=set_example_query,
1320
- inputs=[gr.Textbox(value=example, visible=False)],
1321
- outputs=[query_input]
1322
- )
1323
-
1324
- with gr.Row():
1325
- submit_button = gr.Button("Detect Hallucinations", variant="primary", scale=1)
1326
-
1327
- # Error message
1328
- error_message = gr.HTML(
1329
- label="Status",
1330
- visible=False
1331
- )
1332
-
1333
- # Progress display
1334
- progress_display = gr.HTML(
1335
- value=progress_tracker.get_html_status(),
1336
- visible=True
1337
- )
1338
-
1339
- # Results display
1340
- results_accordion = gr.HTML(visible=False)
1341
-
1342
- # Add feedback stats display
1343
- feedback_stats = gr.HTML(visible=True)
1344
-
1345
- # Feedback section
1346
- with gr.Accordion("Provide Feedback", open=False, visible=False) as feedback_accordion:
1347
- gr.Markdown("### Help Improve the System")
1348
- gr.Markdown("Your feedback helps us refine the hallucination detection system.")
1349
-
1350
- feedback_input = gr.Radio(
1351
- label="Is the hallucination detection accurate?",
1352
- choices=["Yes, correct detection", "No, incorrectly flagged hallucination", "No, missed hallucination", "Unsure/Other"],
1353
- value="Yes, correct detection"
1354
- )
1355
-
1356
- feedback_text = gr.Textbox(
1357
- label="Additional comments (optional)",
1358
- placeholder="Please provide any additional observations or details...",
1359
- lines=2
1360
- )
1361
-
1362
- feedback_button = gr.Button("Submit Feedback", variant="secondary")
1363
- feedback_status = gr.Textbox(label="Feedback Status", interactive=False, visible=False)
1364
-
1365
- # Initialize feedback stats
1366
- initial_stats = detector.get_feedback_stats()
1367
- if initial_stats:
1368
- feedback_stats.value = f"""
1369
- <div class="stats-section">
1370
- <div class="stat-item">
1371
- <div class="stat-value">{initial_stats['total_feedback']}</div>
1372
- <div class="stat-label">Total Feedback</div>
1373
- </div>
1374
- <div class="stat-item">
1375
- <div class="stat-value">{initial_stats['hallucinations_detected']}</div>
1376
- <div class="stat-label">Hallucinations Found</div>
1377
- </div>
1378
- <div class="stat-item">
1379
- <div class="stat-value">{initial_stats['no_hallucinations']}</div>
1380
- <div class="stat-label">No Hallucinations</div>
1381
- </div>
1382
- <div class="stat-item">
1383
- <div class="stat-value">{initial_stats['average_confidence']}</div>
1384
- <div class="stat-label">Avg. Confidence</div>
1385
- </div>
1386
- </div>
1387
- """
1388
-
1389
- # Hidden state to store results for feedback
1390
- hidden_results = gr.State()
1391
-
1392
- # Set up event handlers
1393
- submit_button.click(
1394
- fn=start_processing,
1395
- inputs=[query_input],
1396
- outputs=[progress_display, results_accordion, feedback_accordion, hidden_results],
1397
- queue=False
1398
- ).then(
1399
- fn=process_query_and_display_results,
1400
- inputs=[query_input],
1401
- outputs=[progress_display, results_accordion, feedback_accordion, hidden_results]
1402
- )
1403
-
1404
- feedback_button.click(
1405
- fn=combine_feedback,
1406
- inputs=[feedback_input, feedback_text, hidden_results],
1407
- outputs=[feedback_status, feedback_stats]
1408
- )
1409
-
1410
- # Footer
1411
- gr.HTML(
1412
- """
1413
- <footer>
1414
- <p>Paraphrase-based Approach for Scrutinizing Systems (PAS2) - Advanced Hallucination Detection</p>
1415
- <p>Using Mistral Large for generation and OpenAI o3-mini as judge</p>
1416
- </footer>
1417
- """
1418
- )
1419
-
1420
- return interface
1421
-
1422
- # Add a test function to demonstrate progress bar in isolation
1423
- def test_progress():
1424
- """Simple test function to demonstrate progress bar"""
1425
- import gradio as gr
1426
- import time
1427
-
1428
- def slow_process(progress=gr.Progress()):
1429
- progress(0, desc="Starting process...")
1430
- time.sleep(0.5)
1431
-
1432
- # Phase 1: Generating paraphrases
1433
- progress(0.15, desc="Generating paraphrases...")
1434
- time.sleep(1)
1435
- progress(0.3, desc="Paraphrases generated")
1436
- time.sleep(0.5)
1437
-
1438
- # Phase 2: Getting responses
1439
- progress(0.35, desc="Getting responses...")
1440
- # Show incremental progress for responses
1441
- for i in range(3):
1442
- time.sleep(0.8)
1443
- prog = 0.35 + (0.3 * ((i+1) / 3))
1444
- progress(prog, desc=f"Getting responses ({i+1}/3)...")
1445
-
1446
- progress(0.65, desc="All responses received")
1447
- time.sleep(0.5)
1448
-
1449
- # Phase 3: Analyzing
1450
- progress(0.7, desc="Analyzing responses for hallucinations...")
1451
- time.sleep(2)
1452
-
1453
- # Complete
1454
- progress(1.0, desc="Analysis complete!")
1455
- return "Process completed successfully!"
1456
-
1457
- with gr.Blocks() as demo:
1458
- with gr.Row():
1459
- btn = gr.Button("Start Process")
1460
- output = gr.Textbox(label="Result")
1461
-
1462
- btn.click(fn=slow_process, outputs=output)
1463
-
1464
- demo.launch()
1465
-
1466
- # Main application entry point
1467
- if __name__ == "__main__":
1468
- logger.info("Starting PAS2 Hallucination Detector")
1469
- interface = create_interface()
1470
- logger.info("Launching Gradio interface...")
1471
- interface.launch(
1472
- server_name="0.0.0.0", # Bind to all interfaces
1473
- server_port=7860, # Default Hugging Face Spaces port
1474
- show_api=False,
1475
- quiet=True, # Changed to True for Hugging Face deployment
1476
- share=False,
1477
- max_threads=10,
1478
- debug=False # Changed to False for production deployment
1479
- )
1480
-
1481
- # Uncomment this line to run the test function instead of the main interface
1482
- # if __name__ == "__main__":
1483
- # test_progress()