""" | |
Utility functions for query processing and rewriting. | |
""" | |
import logging
import os
import re
import time

from openai import OpenAI

from prompt_template import (
    Prompt_template_translation,
    Prompt_template_relevance,
    Prompt_template_autism_confidence,
    Prompt_template_autism_rewriter,
    Prompt_template_answer_autism_relevance
)
# Set up module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize OpenAI-compatible client pointed at DeepInfra.
# SECURITY: the API key was hard-coded in source. Prefer the environment
# variable; the old literal is kept only as a backward-compatible fallback.
# TODO(review): rotate this key and delete the hard-coded fallback.
DEEPINFRA_API_KEY = os.environ.get(
    "DEEPINFRA_API_KEY",
    "285LUJulGIprqT6hcPhiXtcrphU04FG4",
)
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url="https://api.deepinfra.com/v1/openai",
)
def call_llm(model: str, messages: list[dict], temperature: float = 0.0, timeout: int = 30, **kwargs) -> str:
    """Call the LLM with given messages and return the stripped response text.

    Args:
        model: Model identifier passed to the API.
        messages: Chat messages in OpenAI format ({"role": ..., "content": ...}).
        temperature: Sampling temperature (0.0 = deterministic).
        timeout: Request timeout in seconds.
        **kwargs: Extra keyword arguments forwarded to the API call.

    Returns:
        The model's reply text. On failure, a heuristic fallback: for
        translation prompts, the original query text; otherwise "0".
    """
    try:
        logger.info("Making API call to %s with timeout %ss", model, timeout)
        start_time = time.time()
        resp = openai.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            timeout=timeout,
            **kwargs
        )
        elapsed = time.time() - start_time
        logger.info("API call completed in %.2fs", elapsed)
        content = resp.choices[0].message.content
        # content may legitimately be None (e.g. tool-call replies); the old
        # unconditional .strip() raised AttributeError and fell into the
        # generic fallback, masking the real condition.
        return content.strip() if content else ""
    except Exception as e:
        logger.error("API call failed: %s", e)
        # Return fallback response. Guard the lookup: an empty/odd `messages`
        # previously raised IndexError/KeyError inside this handler.
        first_content = ""
        if messages:
            try:
                first_content = messages[0].get("content", "") or ""
            except AttributeError:
                first_content = ""
        if "translation" in str(messages).lower():
            # For translation, return the original query
            return first_content.split("Query: ")[-1] if "Query: " in first_content else "Error"
        else:
            # For relevance, assume not related
            return "0"
def enhanced_autism_relevance_check(query: str) -> dict:
    """
    Enhanced autism relevance checking with detailed analysis.

    Args:
        query: The (already translated/corrected) user query.

    Returns:
        A dict with keys "score" (int, 0-100), "category", "action"
        ("accept_as_is" | "rewrite_for_autism" | "conditional_rewrite" |
        "reject"), and "reasoning". On failure, score 0 / category "error".
    """
    try:
        logger.info(f"Enhanced autism relevance check for: '{query[:50]}...'")
        # Use the enhanced confidence prompt
        confidence_prompt = Prompt_template_autism_confidence.format(query=query)
        response = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": confidence_prompt}],
            reasoning_effort="none",
            timeout=15
        )
        # Extract the first integer in the reply, clamped to [0, 100].
        # (Previously hidden behind a bare `except:` plus a redundant
        # function-local `import re`; `re` is now a module-level import.)
        confidence_score = 0
        numbers = re.findall(r'\d+', response)
        if numbers:
            confidence_score = max(0, min(100, int(numbers[0])))
        # Map the score band to category / action / reasoning.
        if confidence_score >= 85:
            category = "directly_autism_related"
            action = "accept_as_is"
            reasoning = "Directly mentions autism or autism-specific topics"
        elif confidence_score >= 70:
            category = "highly_autism_relevant"
            action = "accept_as_is"
            reasoning = "Core autism symptoms or characteristics"
        elif confidence_score >= 55:
            category = "significantly_autism_relevant"
            action = "rewrite_for_autism"
            reasoning = "Common comorbidity or autism-related issue"
        elif confidence_score >= 40:
            category = "moderately_autism_relevant"
            action = "rewrite_for_autism"
            reasoning = "Broader developmental or family concern related to autism"
        elif confidence_score >= 25:
            category = "somewhat_autism_relevant"
            action = "conditional_rewrite"
            reasoning = "General topic with potential autism applications"
        else:
            category = "not_autism_relevant"
            action = "reject"
            reasoning = "Not related to autism or autism care"
        result = {
            "score": confidence_score,
            "category": category,
            "action": action,
            "reasoning": reasoning
        }
        logger.info(f"Enhanced relevance result: {result}")
        return result
    except Exception as e:
        logger.error(f"Error in enhanced_autism_relevance_check: {e}")
        return {
            "score": 0,
            "category": "error",
            "action": "reject",
            "reasoning": "Error during processing"
        }
def check_autism_confidence(query: str) -> int:
    """
    Check autism relevance confidence score (0-100).

    Args:
        query: The user query to score.

    Returns:
        The confidence score as an integer clamped to [0, 100]; 0 when no
        numeric score can be parsed or when any error occurs.
    """
    try:
        logger.info(f"Checking autism confidence for query: '{query[:50]}...'")
        confidence_prompt = Prompt_template_autism_confidence.format(query=query)
        response = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": confidence_prompt}],
            reasoning_effort="none",
            timeout=15
        )
        # Extract the first integer from the response and clamp to [0, 100].
        # (The old bare `except:` and function-local `import re` are gone;
        # `re` is imported at module level.)
        confidence_score = 0
        numbers = re.findall(r'\d+', response)
        if numbers:
            confidence_score = max(0, min(100, int(numbers[0])))
        else:
            logger.warning(f"No numeric score found in response: {response}")
        logger.info(f"Autism confidence score: {confidence_score}")
        return confidence_score
    except Exception as e:
        logger.error(f"Error in check_autism_confidence: {e}")
        return 0
def rewrite_query_for_autism(query: str) -> str:
    """
    Automatically rewrite a query to be autism-specific.

    Falls back to a generic "How does autism relate to ...?" phrasing when
    the model fails or returns an empty/error reply.
    """
    try:
        logger.info(f"Rewriting query for autism: '{query[:50]}...'")
        prompt = Prompt_template_autism_rewriter.format(query=query)
        reply = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": prompt}],
            reasoning_effort="none",
            timeout=15
        )
        if reply == "Error" or not reply.strip():
            logger.warning("Rewriting failed, using fallback")
            final_query = f"How does autism relate to {query.lower()}?"
        else:
            final_query = reply.strip()
        logger.info(f"Query rewritten to: '{final_query[:50]}...'")
        return final_query
    except Exception as e:
        logger.error(f"Error in rewrite_query_for_autism: {e}")
        return f"How does autism relate to {query.lower()}?"
def check_answer_autism_relevance(answer: str) -> int:
    """
    Check if an answer is sufficiently related to autism (0-100 score).
    Used for document-based queries to filter non-autism answers.

    Args:
        answer: The candidate answer text to score.

    Returns:
        Relevance score as an integer clamped to [0, 100]; 0 when no numeric
        score can be parsed or when any error occurs.
    """
    try:
        logger.info(f"Checking answer autism relevance for: '{answer[:50]}...'")
        relevance_prompt = Prompt_template_answer_autism_relevance.format(answer=answer)
        response = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": relevance_prompt}],
            reasoning_effort="none",
            timeout=15
        )
        # Extract the first integer from the response and clamp to [0, 100].
        # (Replaces a bare `except:` and a redundant function-local
        # `import re`; `re` is imported at module level.)
        relevance_score = 0
        numbers = re.findall(r'\d+', response)
        if numbers:
            relevance_score = max(0, min(100, int(numbers[0])))
        else:
            logger.warning(f"No numeric score found in response: {response}")
        logger.info(f"Answer autism relevance score: {relevance_score}")
        return relevance_score
    except Exception as e:
        logger.error(f"Error in check_answer_autism_relevance: {e}")
        return 0
def process_query_for_rewrite(query: str) -> tuple[str, bool, str]:
    """
    Enhanced query processing with sophisticated autism relevance detection.

    Scoring bands (see enhanced_autism_relevance_check):
        85-100 -> directly autism-related, use as-is
        70-84  -> highly autism-relevant (core symptoms), use as-is
        55-69  -> significantly autism-relevant (comorbidities), rewrite
        40-54  -> moderately autism-relevant, rewrite
        25-39  -> somewhat relevant, conditional rewrite (auto-rewritten)
        0-24   -> not autism-related, reject

    Returns:
        (processed_query, is_autism_related, rewritten_query_if_needed) —
        note the third element is currently always "".
    """
    try:
        logger.info(f"Processing query with enhanced confidence logic: '{query[:50]}...'")
        start_time = time.time()
        # Step 1: Translate and correct the query
        logger.info("Step 1: Translating/correcting query")
        corrected_query = call_llm(
            model="Qwen/Qwen3-32B",
            messages=[{"role": "user", "content": Prompt_template_translation.format(query=query)}],
            reasoning_effort="none",
            timeout=15
        )
        if corrected_query == "Error":
            logger.warning("Translation failed, using original query")
            corrected_query = query
        # Step 2: Get enhanced autism relevance analysis
        logger.info("Step 2: Enhanced autism relevance checking")
        relevance_result = enhanced_autism_relevance_check(corrected_query)
        confidence_score = relevance_result["score"]
        action = relevance_result["action"]
        reasoning = relevance_result["reasoning"]
        logger.info(f"Relevance analysis: {confidence_score}% - {reasoning}")
        # Step 3: Take action based on enhanced analysis. Build the result,
        # then log elapsed time BEFORE returning — the timing log previously
        # sat after the return statements and was unreachable dead code.
        if action == "accept_as_is":
            logger.info(f"High relevance ({confidence_score}%) - accepting as-is: {reasoning}")
            result = (corrected_query, True, "")
        elif action == "rewrite_for_autism":
            logger.info(f"Moderate relevance ({confidence_score}%) - rewriting for autism: {reasoning}")
            result = (rewrite_query_for_autism(corrected_query), True, "")
        elif action == "conditional_rewrite":
            # For somewhat relevant queries, automatically rewrite (could be enhanced with user confirmation)
            logger.info(f"Low-moderate relevance ({confidence_score}%) - conditionally rewriting: {reasoning}")
            result = (rewrite_query_for_autism(corrected_query), True, "")
        else:  # action == "reject"
            logger.info(f"Low relevance ({confidence_score}%) - rejecting: {reasoning}")
            result = (corrected_query, False, "")
        elapsed = time.time() - start_time
        logger.info(f"Enhanced query processing completed in {elapsed:.2f}s")
        return result
    except Exception as e:
        logger.error(f"Error in process_query_for_rewrite: {e}")
        # Fallback: return original query as not autism-related
        return query, False, ""
def get_non_autism_response() -> str:
    """Return a more human-like response for non-autism queries."""
    parts = [
        "Hi there! I appreciate you reaching out to me. I'm Wisal, and I specialize specifically in autism and Autism Spectrum Disorders. ",
        "I noticed your question isn't quite related to autism topics. I'd love to help you, but I'm most effective when answering ",
        "questions about autism, ASD, autism support strategies, therapies, or related concerns.\n\n",
        "Could you try asking me something about autism instead? I'm here and ready to help with any autism-related questions you might have! \U0001F60A",
    ]
    return "".join(parts)
def get_non_autism_answer_response() -> str:
    """Return a more human-like response when document answers are not autism-related."""
    parts = [
        "I'm sorry, but the information I found in the document doesn't seem to be related to autism or Autism Spectrum Disorders. ",
        "Since I'm Wisal, your autism specialist, I want to make sure I'm providing you with relevant, autism-focused information. ",
        "Could you try asking a question that's more specifically about autism? I'm here to help with any autism-related topics! \U0001F60A",
    ]
    return "".join(parts)