"""
Agent module for the Fake News Detector application.
This module implements a LangGraph-based agent that orchestrates
the fact-checking process. It defines the agent setup, tools,
and processing pipeline for claim verification.
"""
import os
import time
import logging
import traceback
from langchain_core.tools import tool
from langchain.prompts import PromptTemplate
from langgraph.prebuilt import create_react_agent
from utils.models import get_llm_model
from utils.performance import PerformanceTracker
from modules.claim_extraction import extract_claims
from modules.evidence_retrieval import retrieve_combined_evidence
from modules.classification import classify_with_llm, aggregate_evidence
from modules.explanation import generate_explanation
# Configure logger
logger = logging.getLogger("misinformation_detector")
# Reference to global performance tracker
performance_tracker = PerformanceTracker()
# Define LangGraph Tools
@tool
def claim_extractor(query):
"""
Tool that extracts factual claims from a given text.
Args:
query (str): Text containing potential factual claims
Returns:
str: Extracted factual claim
"""
performance_tracker.log_claim_processed()
return extract_claims(query)
@tool
def evidence_retriever(query):
"""
Tool that retrieves evidence from multiple sources for a claim.
Args:
query (str): The factual claim to gather evidence for
Returns:
list: List of evidence items from various sources
"""
return retrieve_combined_evidence(query)
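# Note: downstream handling in format_response() expects this to be a list
# (typically of text snippets); non-list content is wrapped defensively there.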
@tool
def truth_classifier(query, evidence):
"""
Tool that classifies the truthfulness of a claim based on evidence.
Args:
query (str): The factual claim to classify
evidence (list): Evidence items to evaluate
Returns:
str: JSON string containing verdict, confidence, and results
"""
classification_results = classify_with_llm(query, evidence)
truth_label, confidence = aggregate_evidence(classification_results)
# Debug logging
logger.info(f"Classification results: {len(classification_results)} items")
logger.info(f"Aggregate result: {truth_label}, confidence: {confidence}")
# Ensure confidence is at least 0.6 for any definitive verdict
if "True" in truth_label or "False" in truth_label:
confidence = max(confidence, 0.6)
# Return a dictionary with all needed information
result = {
"verdict": truth_label,
"confidence": confidence,
"results": classification_results
}
# Convert to string for consistent handling
import json
return json.dumps(result)
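# Illustrative shape of the JSON string returned by truth_classifier
# (values below are examples only, not real output):
# {"verdict": "True", "confidence": 0.72, "results": [...]}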
@tool
def explanation_generator(claim, evidence_results, truth_label):
"""
Tool that generates a human-readable explanation for the verdict.
Args:
claim (str): The factual claim being verified
evidence_results (list): Evidence items and classification results
truth_label (str): The verdict (True/False/Uncertain)
Returns:
str: Natural language explanation of the verdict
"""
explanation = generate_explanation(claim, evidence_results, truth_label)
logger.info(f"Generated explanation: {explanation[:100]}...")
return explanation
def setup_agent():
"""
Create and configure a ReAct agent with the fact-checking tools.
This function configures a LangGraph ReAct agent with all the
necessary tools for fact checking, including claim extraction,
evidence retrieval, classification, and explanation generation.
Returns:
object: Configured LangGraph agent ready for claim processing
Raises:
ValueError: If OpenAI API key is not set
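
    Example:
        agent = setup_agent()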
"""
# Make sure OpenAI API key is set
if "OPENAI_API_KEY" not in os.environ or not os.environ["OPENAI_API_KEY"].strip():
logger.error("OPENAI_API_KEY environment variable not set or empty.")
raise ValueError("OpenAI API key is required")
# Define tools with any customizations
tools = [
claim_extractor,
evidence_retriever,
truth_classifier,
explanation_generator
]
# Define the prompt template with clearer, more efficient instructions
FORMAT_INSTRUCTIONS_TEMPLATE = """
Use the following format:
Question: the input question you must answer
Action: the action to take, should be one of: {tool_names}
Action Input: the input to the action
Observation: the result of the action
... (this Action/Action Input/Observation can repeat N times)
Final Answer: the final answer to the original input question
"""
prompt = PromptTemplate(
input_variables=["input", "tool_names"],
template=f"""
You are a fact-checking assistant that verifies claims by gathering evidence and
determining their truthfulness. Follow these exact steps in sequence:
1. Call claim_extractor to extract the main factual claim
2. Call evidence_retriever to gather evidence about the claim
3. Call truth_classifier to evaluate the claim using the evidence
4. Call explanation_generator to explain the result
5. Provide your Final Answer that summarizes everything
Execute these steps in order without unnecessary thinking steps between tool calls.
Be direct and efficient in your verification process.
{FORMAT_INSTRUCTIONS_TEMPLATE}
"""
)
try:
# Get the LLM model
model = get_llm_model()
        # Create the ReAct agent (note: the custom prompt defined above is not
        # passed in here, so create_react_agent falls back to its default prompt)
        graph = create_react_agent(model, tools=tools)
logger.info("Agent created successfully")
return graph
    except Exception as e:
        logger.error(f"Error creating agent: {str(e)}")
        raise
def process_claim(claim, agent=None, recursion_limit=20):
"""
Process a claim to determine its truthfulness using the agent.
This function invokes the LangGraph agent to process a factual claim,
extract supporting evidence, evaluate the claim's truthfulness, and
generate a human-readable explanation.
Args:
claim (str): The factual claim to be verified
        agent (object, optional): Initialized LangGraph agent. If None, an error is logged and None is returned.
recursion_limit (int, optional): Maximum recursion depth for agent. Default: 20.
Higher values allow more complex reasoning but increase processing time.
Returns:
dict: Result dictionary containing:
- claim: Extracted factual claim
- evidence: List of evidence pieces
- evidence_count: Number of evidence pieces
- classification: Verdict (True/False/Uncertain)
- confidence: Confidence score (0-1)
- explanation: Human-readable explanation of the verdict
- final_answer: Final answer from the agent
- Or error information if processing failed
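
    Example (illustrative claim):
        agent = setup_agent()
        result = process_claim("The Eiffel Tower is in Paris.", agent=agent)
        print(result["classification"], result["confidence"])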
"""
if agent is None:
logger.error("Agent not initialized. Call setup_agent() first.")
return None
start_time = time.time()
logger.info(f"Processing claim with agent: {claim}")
try:
# Format inputs for the agent
inputs = {"messages": [("user", claim)]}
        # Set configuration with the recursion limit (lower values finish faster)
config = {"recursion_limit": recursion_limit}
# Invoke the agent
response = agent.invoke(inputs, config)
# Format the response
result = format_response(response)
# Log performance
elapsed = time.time() - start_time
logger.info(f"Claim processed in {elapsed:.2f} seconds")
return result
except Exception as e:
logger.error(f"Error processing claim with agent: {str(e)}")
logger.error(traceback.format_exc())
return {"error": str(e)}
def format_response(response):
"""
Format the agent's response into a structured result.
This function extracts key information from the agent's response,
including the claim, evidence, classification, and explanation.
It also performs error handling and provides fallback values.
Args:
response (dict): Raw response from the LangGraph agent
Returns:
dict: Structured result containing claim verification data
"""
try:
if not response or "messages" not in response:
return {"error": "Invalid response format"}
messages = response.get("messages", [])
# Initialize result container with default values
result = {
"claim": None,
"evidence": [],
"evidence_count": 0,
"classification": "Uncertain",
"confidence": 0.2, # Default low confidence
"explanation": "Insufficient evidence to evaluate this claim.",
"final_answer": None,
"thoughts": []
}
# Track if we found results from each tool
found_tools = {
"claim_extractor": False,
"evidence_retriever": False,
"truth_classifier": False,
"explanation_generator": False
}
# Extract information from messages
tool_outputs = {}
for idx, message in enumerate(messages):
# Extract agent thoughts
            # LangChain assistant messages report type "ai"; accept "assistant" too for safety
            if hasattr(message, "content") and getattr(message, "type", "") in ("ai", "assistant"):
content = message.content
if "Thought:" in content:
thought_parts = content.split("Thought:", 1)
if len(thought_parts) > 1:
thought = thought_parts[1].split("\n")[0].strip()
result["thoughts"].append(thought)
# Extract tool outputs
if hasattr(message, "type") and message.type == "tool":
tool_name = getattr(message, "name", "unknown")
# Store tool outputs
tool_outputs[tool_name] = message.content
# Extract specific information
if tool_name == "claim_extractor":
found_tools["claim_extractor"] = True
if message.content:
result["claim"] = message.content
elif tool_name == "evidence_retriever":
found_tools["evidence_retriever"] = True
# Handle string representation of a list
if message.content:
if isinstance(message.content, list):
result["evidence"] = message.content
result["evidence_count"] = len(message.content)
elif isinstance(message.content, str) and message.content.startswith("[") and message.content.endswith("]"):
try:
import ast
parsed_content = ast.literal_eval(message.content)
if isinstance(parsed_content, list):
result["evidence"] = parsed_content
result["evidence_count"] = len(parsed_content)
else:
result["evidence"] = [message.content]
result["evidence_count"] = 1
                            except (ValueError, SyntaxError):
result["evidence"] = [message.content]
result["evidence_count"] = 1
else:
result["evidence"] = [message.content]
result["evidence_count"] = 1
logger.warning(f"Evidence retrieved is not a list: {type(message.content)}")
elif tool_name == "truth_classifier":
found_tools["truth_classifier"] = True
# Log the incoming content for debugging
logger.info(f"Truth classifier content type: {type(message.content)}")
logger.info(f"Truth classifier content: {message.content}")
# Handle JSON formatted result from truth_classifier
if isinstance(message.content, str):
try:
import json
# Parse the JSON string
parsed_content = json.loads(message.content)
# Extract the values from the parsed content
result["classification"] = parsed_content.get("verdict", "Uncertain")
result["confidence"] = float(parsed_content.get("confidence", 0.2))
result["classification_results"] = parsed_content.get("results", [])
logger.info(f"Extracted from JSON: verdict={result['classification']}, confidence={result['confidence']}")
except json.JSONDecodeError:
logger.warning(f"Could not parse truth classifier JSON: {message.content}")
except Exception as e:
logger.warning(f"Error extracting from truth classifier output: {e}")
else:
logger.warning(f"Unexpected truth_classifier content format: {message.content}")
elif tool_name == "explanation_generator":
found_tools["explanation_generator"] = True
if message.content:
result["explanation"] = message.content
logger.info(f"Found explanation from tool: {message.content[:100]}...")
# Get final answer from last message
elif idx == len(messages) - 1 and hasattr(message, "content"):
result["final_answer"] = message.content
# Log which tools weren't found
missing_tools = [tool for tool, found in found_tools.items() if not found]
if missing_tools:
logger.warning(f"Missing tool outputs in response: {', '.join(missing_tools)}")
# FALLBACK: If we have truth classification but explanation is missing, generate it now
if found_tools["truth_classifier"] and not found_tools["explanation_generator"]:
logger.info("Explanation generator was not called by the agent, using fallback explanation generation")
try:
# Get the necessary inputs for explanation generation
claim = result["claim"]
evidence = result["evidence"]
truth_label = result["classification"]
confidence_value = result["confidence"] # Pass the confidence value
classification_results = result.get("classification_results", [])
# Choose the best available evidence for explanation
explanation_evidence = classification_results if classification_results else evidence
# Generate explanation with confidence value
explanation = generate_explanation(claim, explanation_evidence, truth_label, confidence_value)
# Use the generated explanation
if explanation:
logger.info(f"Generated fallback explanation: {explanation[:100]}...")
result["explanation"] = explanation
except Exception as e:
logger.error(f"Error generating fallback explanation: {e}")
# Make sure evidence exists
        if result["evidence_count"] > 0 and not result["evidence"]:
logger.warning("Evidence count is non-zero but evidence list is empty. This is a data inconsistency.")
result["evidence_count"] = 0
# Add debug info about the final result
logger.info(f"Final classification: {result['classification']}, confidence: {result['confidence']}")
logger.info(f"Final explanation: {result['explanation'][:100]}...")
# Add performance metrics
result["performance"] = performance_tracker.get_summary()
# Memory management - limit the size of evidence and thoughts
# To keep memory usage reasonable for web deployment
if "evidence" in result and isinstance(result["evidence"], list):
limited_evidence = []
for ev in result["evidence"]:
if isinstance(ev, str) and len(ev) > 500:
limited_evidence.append(ev[:497] + "...")
else:
limited_evidence.append(ev)
result["evidence"] = limited_evidence
# Limit thoughts to conserve memory
if "thoughts" in result and len(result["thoughts"]) > 10:
result["thoughts"] = result["thoughts"][:10]
return result
except Exception as e:
logger.error(f"Error formatting agent response: {str(e)}")
logger.error(traceback.format_exc())
return {
"error": str(e),
"traceback": traceback.format_exc(),
"classification": "Error",
"confidence": 0.1,
"explanation": "An error occurred while processing this claim."
}
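

# Minimal manual test sketch; assumes OPENAI_API_KEY is set in the environment and
# that the project's utils/ and modules/ packages are importable. Not invoked by the app.
if __name__ == "__main__":
    test_agent = setup_agent()
    outcome = process_claim("The Great Wall of China is visible from space.", agent=test_agent)
    print(outcome)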