import os
import gc
import time
import asyncio
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any

import torch
import rustworkx as rx
import numpy as np
from pyvis.network import Network
from sentence_transformers import SentenceTransformer
from bert_score.scorer import BERTScorer
from tenacity import RetryError
from openai import RateLimitError
from anthropic import RateLimitError as AnthropicRateLimitError
from google.api_core.exceptions import ResourceExhausted

from src.query_processing.late_chunking.late_chunker import LateChunker
from src.query_processing.query_processor import QueryProcessor
from src.reasoning.reasoner import Reasoner
from src.utils.api_key_manager import APIKeyManager
from src.search.search_engine import SearchEngine
from src.crawl.crawler import CustomCrawler

class GraphRAG:
    def __init__(self, num_workers: int = 1):
        """Initialize graph and required components."""
        self.graphs = {}
        self.current_graph_id = None

        self.num_workers = num_workers
        self.search_engine = SearchEngine()
        self.query_processor = QueryProcessor()
        self.reasoner = Reasoner()
        self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
        self.chunking = LateChunker()
        self.llm = APIKeyManager().get_llm()

        self.model = SentenceTransformer(
            "dunzhang/stella_en_400M_v5",
            trust_remote_code=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        self.scorer = BERTScorer(
            model_type="roberta-base",
            lang="en",
            rescale_with_baseline=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )

        self.root_node_id = "QR"
        self.node_counter = 0
        self.sub_node_counter = 0
        self.cross_connections = set()

        self.semaphore = asyncio.Semaphore(min(num_workers * 2, 12))
        self.executor = ThreadPoolExecutor(max_workers=self.num_workers)
        self.on_event_callback = None

    def set_on_event_callback(self, callback):
        """Register a single callback to be triggered for various event types."""
        self.on_event_callback = callback
|
|
|
    async def emit_event(self, event_type: str, data: dict):
        """Helper method to safely emit an event if a callback is registered."""
        if self.on_event_callback:
            if asyncio.iscoroutinefunction(self.on_event_callback):
                return await self.on_event_callback(event_type, data)
            else:
                return self.on_event_callback(event_type, data)

    def _get_current_graph_data(self):
        if self.current_graph_id is None or self.current_graph_id not in self.graphs:
            raise Exception("Error: No current graph selected")

        return self.graphs[self.current_graph_id]
|
|
|
def add_node(self, node_id: str, query: str, data: str = "", role: str = None): |
|
"""Add a node to the current graph.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
|
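        # Encode the query once at insertion time; the stored embedding is reused later for similarity and relevance edges.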
|
embedding = self.model.encode(query).tolist() |
|
node_data = { |
|
"id": node_id, |
|
"query": query, |
|
"data": data, |
|
"role": role, |
|
"embedding": embedding, |
|
"pagerank": 0, |
|
"graph_id": self.current_graph_id |
|
} |
|
node_index = graph.add_node(node_data) |
|
node_map[node_id] = node_index |
|
|
|
print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'") |
|
|
|
def _has_path(self, source_idx: int, target_idx: int) -> bool: |
|
"""Helper method to check if there is a path from source to target in the current graph.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
visited = set() |
|
stack = [source_idx] |
|
|
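        # Iterative depth-first search over out-neighbors; add_edge uses this to reject edges that would close a cycle.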
|
while stack: |
|
current = stack.pop() |
|
|
|
if current == target_idx: |
|
return True |
|
|
|
if current in visited: |
|
continue |
|
|
|
visited.add(current) |
|
for neighbor in graph.neighbors(current): |
|
stack.append(neighbor) |
|
|
|
return False |
|
|
|
def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None): |
|
"""Add an edge between two nodes in a way that preserves a DAG structure.""" |
|
if self.current_graph_id is None: |
|
raise Exception("Error: No current graph selected") |
|
|
|
if node1 == node2: |
|
print(f"Cannot add edge to the same node {node1}!") |
|
return |
|
|
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
if node1 not in node_map or node2 not in node_map: |
|
print(f"One or both nodes {node1}, {node2} do not exist in the current graph.") |
|
return |
|
|
|
idx1 = node_map[node1] |
|
idx2 = node_map[node2] |
|
|
|
|
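        # If a path already runs from node2 back to node1, adding node1 -> node2 would create a cycle.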
|
        if self._has_path(idx2, idx1):
            print(f"Skipping edge {node1} -> {node2}: it would create a cycle in the DAG!")
            return
|
|
|
if relationship_type and weight: |
|
edge_data = {"type": relationship_type, "weight": weight} |
|
graph.add_edge(idx1, idx2, edge_data) |
|
else: |
|
raise ValueError("Error: Relationship type and weight must be provided") |
|
print(f"Added edge between '{node1}' and '{node2}' in graph '{self.current_graph_id}' (type='{relationship_type}', weight={weight})") |
|
|
|
def edge_exists(self, node1: str, node2: str) -> bool: |
|
"""Check if an edge exists between two nodes.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
if node1 not in node_map or node2 not in node_map: |
|
return False |
|
idx1 = node_map[node1] |
|
idx2 = node_map[node2] |
|
|
|
for edge in graph.out_edges(idx1): |
|
if edge[1] == idx2: |
|
return True |
|
|
|
return False |
|
|
|
def graph_exists(self) -> bool: |
|
"""Check if a graph exists.""" |
|
return self.current_graph_id is not None and self.current_graph_id in self.graphs and len(self.graphs[self.current_graph_id]["node_map"]) > 0 |
|
|
|
def get_graphs(self) -> list: |
|
"""Get detailed information about all existing graphs and their nodes.""" |
|
result = [] |
|
for graph_id, data in self.graphs.items(): |
|
metadata = data["metadata"] |
|
node_map = data["node_map"] |
|
graph = data["graph"] |
|
nodes_info = [] |
|
|
|
for node_id, idx in node_map.items(): |
|
node_data = graph.get_node_data(idx) |
|
nodes_info.append({ |
|
"id": node_data.get("id"), |
|
"query": node_data.get("query"), |
|
"data": node_data.get("data"), |
|
"role": node_data.get("role"), |
|
"pagerank": node_data.get("pagerank") |
|
}) |
|
edge_count = len(graph.edge_list()) |
|
result.append({ |
|
"graph_info": { |
|
"graph_id": graph_id, |
|
"created": metadata.get("created"), |
|
"updated": metadata.get("updated"), |
|
"node_count": len(node_map), |
|
"edge_count": edge_count, |
|
"nodes": nodes_info |
|
} |
|
}) |
|
|
|
result.sort(key=lambda x: x["graph_info"]["created"], reverse=True) |
|
return result |
|
|
|
def select_graph(self, graph_id: str) -> bool: |
|
"""Select a specific graph as the current working graph.""" |
|
if graph_id in self.graphs: |
|
self.current_graph_id = graph_id |
|
return True |
|
return False |
|
|
|
    def create_new_graph(self) -> str:
        """Create a new graph instance and return its ID."""
        graph_id = str(uuid.uuid4())
        graph = rx.PyDiGraph()
        node_map = {}
        metadata = {
            "id": graph_id,
            "created": time.time(),
            "updated": time.time()
        }
        self.graphs[graph_id] = {"graph": graph, "node_map": node_map, "metadata": metadata}
        self.current_graph_id = graph_id

        return graph_id
|
|
|
def load_graph(self, node_id: str) -> bool: |
|
"""Load an existing graph structure from memory based on a node ID.""" |
|
|
|
for gid, data in self.graphs.items(): |
|
if node_id in data["node_map"]: |
|
self.current_graph_id = gid |
|
|
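                # Rebuild the SQ/SSQ counters from existing node ids so newly added nodes do not collide.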
|
                for n_id in data["node_map"].keys():
                    # "SSQ" ids also contain "SQ", so check the more specific prefix first.
                    if n_id.startswith("SSQ"):
                        num = int(''.join(filter(str.isdigit, n_id)) or 0)
                        self.sub_node_counter = max(self.sub_node_counter, num)
                    elif n_id.startswith("SQ"):
                        num = int(''.join(filter(str.isdigit, n_id)) or 0)
                        self.node_counter = max(self.node_counter, num)

                self.node_counter += 1
                self.sub_node_counter += 1
|
graph = data["graph"] |
|
node_map = data["node_map"] |
|
|
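                # Restore the set of logical cross-connections recorded in the loaded graph's edges.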
|
for (u, v), edge_data in zip(graph.edge_list(), graph.edges()): |
|
if edge_data.get("type") == "logical": |
|
source_id = graph.get_node_data(u).get("id") |
|
target_id = graph.get_node_data(v).get("id") |
|
connection = tuple(sorted([source_id, target_id])) |
|
self.cross_connections.add(connection) |
|
|
|
print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}") |
|
return True |
|
|
|
print(f"Graph with node_id {node_id} not found.") |
|
|
|
return False |
|
|
|
async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None): |
|
"""Modify an existing graph structure by integrating a new query.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
async def add_as_sibling(node_id: str, query: str): |
|
if node_id not in node_map: |
|
raise ValueError(f"Node {node_id} not found") |
|
|
|
idx = node_map[node_id] |
|
in_edges = graph.in_edges(idx) |
|
|
|
if not in_edges: |
|
raise ValueError(f"No parent found for node {node_id}") |
|
|
|
parent_idx = in_edges[0][0] |
|
parent_data = graph.get_node_data(parent_idx) |
|
parent_id = parent_data.get("id") |
|
|
|
if "SQ" in node_id: |
|
self.node_counter += 1 |
|
new_node_id = f"SQ{self.node_counter}" |
|
else: |
|
self.sub_node_counter += 1 |
|
new_node_id = f"SSQ{self.sub_node_counter}" |
|
|
|
self.add_node(new_node_id, query, role="independent") |
|
self.add_edge(parent_id, new_node_id, relationship_type=in_edges[0][2].get("type")) |
|
|
|
return new_node_id |
|
|
|
async def add_as_child(node_id: str, query: str): |
|
if "SQ" in node_id: |
|
self.sub_node_counter += 1 |
|
new_node_id = f"SSQ{self.sub_node_counter}" |
|
else: |
|
self.node_counter += 1 |
|
new_node_id = f"SQ{self.node_counter}" |
|
|
|
self.add_node(new_node_id, query, role="dependent") |
|
self.add_edge(node_id, new_node_id, relationship_type="logical") |
|
|
|
return new_node_id |
|
|
|
def collect_graph_context() -> list: |
|
"""Collect context from existing graph nodes.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
nodes = [] |
|
|
|
for n_id, idx in node_map.items(): |
|
if n_id == self.root_node_id: |
|
continue |
|
node_data = graph.get_node_data(idx) |
|
nodes.append({ |
|
"id": node_data.get("id"), |
|
"query": node_data.get("query"), |
|
"role": node_data.get("role") |
|
}) |
|
|
|
nodes.sort(key=lambda x: (0 if x["id"].startswith("SQ") else (1 if x["id"].startswith("SSQ") else 2), x["id"])) |
|
level_queries = {} |
|
current_sq = None |
|
|
|
for node in nodes: |
|
node_id = node["id"] |
|
if node_id.startswith("SQ"): |
|
current_sq = node_id |
|
|
|
if current_sq not in level_queries: |
|
level_queries[current_sq] = { |
|
"originalquery": node["query"], |
|
"subqueries": [] |
|
} |
|
level_queries[current_sq]["subqueries"].append({ |
|
"subquery": node["query"], |
|
"role": node["role"], |
|
"dependson": [] |
|
}) |
|
|
|
elif node_id.startswith("SSQ") and current_sq: |
|
level_queries[current_sq]["subqueries"].append({ |
|
"subquery": node["query"], |
|
"role": node["role"], |
|
"dependson": [] |
|
}) |
|
|
|
return list(level_queries.values()) |
|
|
|
if similar_node_id not in node_map: |
|
raise Exception(f"Node {similar_node_id} not found") |
|
|
|
similar_node_data = graph.get_node_data(node_map[similar_node_id]) |
|
has_parent = len(graph.in_edges(node_map[similar_node_id])) > 0 |
|
|
|
context = collect_graph_context() |
|
if similar_node_data.get("role") == "independent": |
|
if has_parent: |
|
new_node_id = await add_as_sibling(similar_node_id, new_query) |
|
else: |
|
new_node_id = await add_as_child(similar_node_id, new_query) |
|
else: |
|
new_node_id = await add_as_child(similar_node_id, new_query) |
|
|
|
await self.build_graph( |
|
query=new_query, |
|
parent_node_id=new_node_id, |
|
depth=1 if "SQ" in new_node_id else 2, |
|
context=context, |
|
session_id=session_id |
|
) |
|
|
|
async def build_graph(self, query: str, data: str = None, parent_node_id: str = None, |
|
depth: int = 0, threshold: float = 0.8, recurse: bool = True, |
|
context: list = None, session_id: str = None, max_tokens_allowed: int = 128000, |
|
node_data_futures: dict = None, sub_nodes_info: list = None, |
|
sub_query_ids: list = None, pre_req_nodes: list = None): |
|
"""Build a new graph structure in memory.""" |
|
async def process_node(node_id: str, sub_query: str, session_id: str, |
|
future: asyncio.Future, max_tokens_allowed: int = max_tokens_allowed): |
|
try: |
|
optimized_query = await self.search_engine.generate_optimized_query(sub_query) |
|
results = await self.search_engine.search( |
|
query=optimized_query, |
|
num_results=10, |
|
exclude_filetypes=["pdf"] |
|
) |
|
await self.emit_event("search_results_fetched", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"optimized_query": optimized_query, |
|
"search_results": results |
|
}) |
|
filtered_urls = await self.search_engine.filter_urls( |
|
sub_query, |
|
"ultra", |
|
results |
|
) |
|
await self.emit_event("search_results_filtered", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"filtered_urls": filtered_urls |
|
}) |
|
urls = [result.get('link', 'No URL') for result in filtered_urls] |
|
search_contents = await self.custom_crawler.fetch_page_contents( |
|
urls, |
|
sub_query, |
|
session_id=session_id, |
|
max_attempts=1, |
|
timeout=30 |
|
) |
|
await self.emit_event("search_contents_fetched", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"contents": search_contents |
|
}) |
|
|
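                # Concatenate the successfully fetched documents; failed fetches are logged and skipped.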
|
contents = "" |
|
for k, content in enumerate(search_contents, 1): |
|
if isinstance(content, Exception): |
|
print(f"Error fetching content: {content}") |
|
elif content: |
|
contents += f"Document {k}:\n{content}\n\n" |
|
|
|
                if contents.strip():
                    token_count = self.llm.get_num_tokens(contents)
                    if token_count > max_tokens_allowed:
                        # Late-chunk the combined documents down to the allowed token budget.
                        contents = await self.chunking.chunker(
                            text=contents,
                            query=sub_query,
                            max_tokens=max_tokens_allowed
                        )
                        print(f"Number of tokens in the fetched content: {token_count}")
                        print(f"Number of tokens after chunking: {self.llm.get_num_tokens(contents)}")
|
|
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
if node_id in node_map: |
|
idx = node_map[node_id] |
|
node_data = graph.get_node_data(idx) |
|
node_data["data"] = contents |
|
if not future.done(): |
|
future.set_result(contents) |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e: |
|
print(f"Error processing node {node_id}: {str(e)}") |
|
if not future.done(): |
|
future.set_exception(e) |
|
except Exception as e: |
|
print(f"Error processing node {node_id}: {str(e)}") |
|
if not future.done(): |
|
future.set_exception(e) |
|
raise e |
|
|
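        # A dependent sub-query waits for its prerequisites, rewrites itself with their results, then runs the normal search pipeline.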
|
async def process_dependent_node(node_id: str, sub_query: str, dep_futures: list, future): |
|
try: |
|
dep_data = [await f for f in dep_futures] |
|
modified_query = await self.query_processor.modify_query( |
|
sub_query, |
|
dep_data |
|
) |
|
loop = asyncio.get_running_loop() |
|
embedding = await loop.run_in_executor( |
|
self.executor, |
|
self.model.encode, |
|
modified_query |
|
) |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
if node_id in node_map: |
|
idx = node_map[node_id] |
|
node_data = graph.get_node_data(idx) |
|
node_data["query"] = modified_query |
|
node_data["embedding"] = embedding.tolist() if hasattr(embedding, "tolist") else embedding |
|
try: |
|
if not future.done(): |
|
await process_node(node_id, modified_query, session_id, future, max_tokens_allowed) |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e: |
|
if not future.done(): |
|
future.set_exception(e) |
|
except Exception as e: |
|
if not future.done(): |
|
future.set_exception(e) |
|
raise e |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e: |
|
print(f"Error processing dependent node {node_id}: {str(e)}") |
|
if not future.done(): |
|
future.set_exception(e) |
|
except Exception as e: |
|
print(f"Error processing dependent node {node_id}: {str(e)}") |
|
if not future.done(): |
|
future.set_exception(e) |
|
raise e |
|
|
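        # Link dependent nodes to their logical prerequisites with explicit edges, avoiding duplicates via self.cross_connections.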
|
def create_cross_connections(): |
|
try: |
|
relationships = self.get_node_relationships(relationship_type='logical') |
|
|
|
for current_node_id, edges in relationships.items(): |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
if current_node_id not in node_map: |
|
continue |
|
|
|
idx = node_map[current_node_id] |
|
node_data = graph.get_node_data(idx) |
|
node_role = (node_data.get("role") or "").lower() |
|
|
|
if node_role == 'dependent': |
|
for source_id, target_id, edge_data in edges['in_edges']: |
|
if not source_id or source_id == self.root_node_id: |
|
continue |
|
|
|
connection = tuple(sorted([current_node_id, source_id])) |
|
if connection not in self.cross_connections: |
|
if not self.edge_exists(source_id, current_node_id): |
|
print(f"Adding cross-connection edge between {source_id} and {current_node_id}") |
|
self.add_edge(source_id, current_node_id, weight=edge_data.get('weight', 1.0), relationship_type='logical') |
|
self.cross_connections.add(connection) |
|
|
|
for source_id, target_id, edge_data in edges['out_edges']: |
|
if not target_id or target_id == self.root_node_id: |
|
continue |
|
|
|
connection = tuple(sorted([current_node_id, target_id])) |
|
if connection not in self.cross_connections: |
|
if not self.edge_exists(current_node_id, target_id): |
|
print(f"Adding cross-connection edge between {current_node_id} and {target_id}") |
|
self.add_edge(current_node_id, target_id, weight=edge_data.get('weight', 1.0), relationship_type='logical') |
|
self.cross_connections.add(connection) |
|
except Exception as e: |
|
print(f"Error creating cross connections: {str(e)}") |
|
raise |
|
|
|
if depth > 1: |
|
return |
|
|
|
if context is None: |
|
context = [] |
|
|
|
if node_data_futures is None: |
|
node_data_futures = {} |
|
if sub_nodes_info is None: |
|
sub_nodes_info = [] |
|
if sub_query_ids is None: |
|
sub_query_ids = [] |
|
if pre_req_nodes is None: |
|
pre_req_nodes = {} |
|
|
|
if parent_node_id is None: |
|
self.add_node(self.root_node_id, query, data) |
|
parent_node_id = self.root_node_id |
|
|
|
intent = await self.query_processor.get_query_intent(query) |
|
|
|
if depth == 0: |
|
response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent) |
|
else: |
|
response_data, sub_queries, roles, dependencies = await self.query_processor.decompose_query_with_dependencies(query, intent, context) |
|
|
|
if response_data: |
|
context.append(response_data) |
|
|
|
if len(sub_queries) > 1 and sub_queries[0] != query: |
|
for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)): |
|
if depth == 0: |
|
await self.emit_event("sub_query_created", { |
|
"depth": depth, |
|
"sub_query": sub_query, |
|
"role": role, |
|
"dependency": dependency, |
|
"parent_node_id": parent_node_id, |
|
}) |
|
|
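                # Depth-0 decompositions get "SQ" ids; deeper decompositions get "SSQ" ids.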
|
if depth == 0: |
|
self.node_counter += 1 |
|
sub_node_id = f"SQ{self.node_counter}" |
|
else: |
|
self.sub_node_counter += 1 |
|
sub_node_id = f"SSQ{self.sub_node_counter}" |
|
|
|
sub_query_ids.append(sub_node_id) |
|
self.add_node(sub_node_id, sub_query, role=role) |
|
future = asyncio.Future() |
|
node_data_futures[sub_node_id] = future |
|
sub_nodes_info.append((sub_node_id, sub_query, role, dependency, future, depth)) |
|
|
|
if role.lower() in ['pre-requisite', 'prerequisite']: |
|
pre_req_nodes[idx] = sub_node_id |
|
|
|
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'): |
|
self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical') |
|
elif role.lower() == 'dependent': |
|
if isinstance(dependency, list) and (len(dependency) == 2 and all(isinstance(d, list) for d in dependency)): |
|
print(f"Dependency: {dependency}") |
|
prev_deps, current_deps = dependency |
|
|
|
if context and prev_deps not in [None, []]: |
|
for dep_idx in prev_deps: |
|
if dep_idx is not None: |
|
for context_data in context: |
|
if 'subqueries' in context_data and dep_idx < len(context_data['subqueries']): |
|
sub_query_data = context_data['subqueries'][dep_idx] |
|
if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data: |
|
dep_query = sub_query_data['subquery'] |
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes: |
|
dep_node_id = matching_nodes[0].get('node_id') |
|
score = matching_nodes[0].get('score', 0) |
|
if score >= 0.9: |
|
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical') |
|
|
|
if current_deps not in [None, []]: |
|
for dep_idx in current_deps: |
|
if dep_idx < len(sub_query_ids): |
|
dep_node_id = sub_query_ids[dep_idx] |
|
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical') |
|
else: |
|
raise ValueError(f"Invalid dependency index: {dep_idx}") |
|
elif len(dependency) > 0: |
|
for dep_idx in dependency: |
|
if dep_idx < len(sub_queries): |
|
dep_node_id = sub_query_ids[dep_idx] |
|
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical') |
|
else: |
|
raise ValueError(f"Invalid dependency index: {dep_idx}") |
|
else: |
|
raise ValueError(f"Invalid dependency: {dependency}") |
|
else: |
|
raise ValueError(f"Unexpected role: {role}") |
|
|
|
if recurse: |
|
recursion_tasks = [] |
|
|
|
for idx, sub_query in enumerate(sub_queries): |
|
try: |
|
sub_node_id = sub_query_ids[idx] |
|
recursion_tasks.append( |
|
self.build_graph( |
|
query=sub_query, |
|
parent_node_id=sub_node_id, |
|
depth=depth + 1, |
|
threshold=threshold, |
|
recurse=recurse, |
|
context=context, |
|
session_id=session_id, |
|
node_data_futures=node_data_futures, |
|
sub_nodes_info=sub_nodes_info, |
|
sub_query_ids=sub_query_ids, |
|
pre_req_nodes=pre_req_nodes |
|
) |
|
) |
|
except Exception as e: |
|
print(f"Failed to create recursion task for sub-query {sub_query}: {e}") |
|
continue |
|
|
|
if recursion_tasks: |
|
try: |
|
await asyncio.gather(*recursion_tasks, return_exceptions=True) |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError) as e: |
|
print(f"Error during recursive processing: {e}") |
|
except Exception as e: |
|
print(f"Error during recursive processing: {e}") |
|
raise e |
|
|
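        # Schedule content fetching: nodes with hierarchical children are resolved to empty content, leaf nodes are searched (dependent ones only after their prerequisites).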
|
futures = {} |
|
all_child_futures = {} |
|
process_tasks = [] |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
for (sub_node_id, sub_query, role, dependency, future, local_depth) in sub_nodes_info: |
|
idx = node_map.get(sub_node_id) |
|
has_children = False |
|
child_futures = [] |
|
if idx is not None: |
|
for (_, child_idx, edge_data) in graph.out_edges(idx): |
|
if edge_data.get("type") == "hierarchical": |
|
has_children = True |
|
child_future = node_data_futures.get(graph.get_node_data(child_idx).get("id")) |
|
if child_future: |
|
child_futures.append(child_future) |
|
if local_depth == 0: |
|
futures[sub_query] = future |
|
all_child_futures[sub_query] = child_futures |
|
if has_children: |
|
if not future.done(): |
|
future.set_result("") |
|
else: |
|
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'): |
|
process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed)) |
|
elif role.lower() == 'dependent': |
|
dep_futures = [] |
|
if isinstance(dependency, list) and len(dependency) == 2: |
|
prev_deps, current_deps = dependency |
|
if context and prev_deps not in [None, []]: |
|
for context_idx, context_data in enumerate(context): |
|
if isinstance(prev_deps, list) and context_idx < len(prev_deps): |
|
context_dep = prev_deps[context_idx] |
|
if (context_dep is not None and isinstance(context_data, dict) |
|
and 'subqueries' in context_data): |
|
if context_dep < len(context_data['subqueries']): |
|
dep_query = context_data['subqueries'][context_dep]['subquery'] |
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes not in [None, []]: |
|
dep_node_id = matching_nodes[0].get('node_id', None) |
|
score = float(matching_nodes[0].get('score', 0)) |
|
if score == 1.0 and dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
                                    elif isinstance(prev_deps, int):
                                        # Guard with the dependency index itself, not the enumeration index.
                                        if prev_deps < len(context_data['subqueries']):
                                            dep_query = context_data['subqueries'][prev_deps]['subquery']
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes not in [None, []]: |
|
dep_node_id = matching_nodes[0].get('node_id', None) |
|
score = matching_nodes[0].get('score', 0) |
|
if score == 1.0 and dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
if current_deps not in [None, []]: |
|
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps |
|
for dep_idx in current_deps_list: |
|
if dep_idx < len(sub_query_ids): |
|
dep_node_id = sub_query_ids[dep_idx] |
|
if dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future)) |
|
else: |
|
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'): |
|
process_tasks.append(process_node(sub_node_id, sub_query, session_id, future, max_tokens_allowed)) |
|
elif role.lower() == 'dependent': |
|
dep_futures = [] |
|
if isinstance(dependency, list) and len(dependency) == 2: |
|
prev_deps, current_deps = dependency |
|
if context and prev_deps not in [None, []]: |
|
for context_idx, context_data in enumerate(context): |
|
if isinstance(prev_deps, list) and context_idx < len(prev_deps): |
|
context_dep = prev_deps[context_idx] |
|
if (context_dep is not None and isinstance(context_data, dict) |
|
and 'subqueries' in context_data): |
|
if context_dep < len(context_data['subqueries']): |
|
dep_query = context_data['subqueries'][context_dep]['subquery'] |
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes not in [None, []]: |
|
dep_node_id = matching_nodes[0].get('node_id', None) |
|
score = float(matching_nodes[0].get('score', 0)) |
|
if score == 1.0 and dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
                                elif isinstance(prev_deps, int):
                                    # Guard with the dependency index itself, not the enumeration index.
                                    if prev_deps < len(context_data['subqueries']):
                                        dep_query = context_data['subqueries'][prev_deps]['subquery']
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes not in [None, []]: |
|
dep_node_id = matching_nodes[0].get('node_id', None) |
|
score = matching_nodes[0].get('score', 0) |
|
if score == 1.0 and dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
if current_deps not in [None, []]: |
|
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps |
|
for dep_idx in current_deps_list: |
|
if dep_idx < len(sub_query_ids): |
|
dep_node_id = sub_query_ids[dep_idx] |
|
if dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
process_tasks.append(process_dependent_node(sub_node_id, sub_query, dep_futures, future)) |
|
|
|
if process_tasks: |
|
await self.emit_event("search_process_started", { |
|
"depth": depth, |
|
"sub_queries": sub_queries, |
|
"roles": roles |
|
}) |
|
|
|
processed_sub_queries = set() |
|
        for sub_query, future in futures.items():
            try:
                parent_content = future.result().strip()
            except Exception:
                # The future may still be pending or may hold an exception; treat both as "no content yet".
                parent_content = ""
|
|
|
            child_futures = all_child_futures.get(sub_query, [])
            any_child_done = any(cf.done() and cf.exception() is None and cf.result().strip() for cf in child_futures)
|
|
|
if parent_content or any_child_done: |
|
await self.emit_event("sub_query_processed", {"sub_query": sub_query}) |
|
processed_sub_queries.add(sub_query) |
|
|
|
await asyncio.gather(*process_tasks) |
|
|
|
if depth == 0: |
|
for sub_query, future in futures.items(): |
|
if sub_query not in processed_sub_queries: |
|
                    try:
                        parent_content = future.result().strip()
                    except Exception:
                        parent_content = ""
|
|
|
                    child_futures = all_child_futures.get(sub_query, [])
                    any_child_done = any(cf.done() and cf.exception() is None and cf.result().strip() for cf in child_futures)
|
|
|
if parent_content or any_child_done: |
|
await self.emit_event("sub_query_processed", {"sub_query": sub_query}) |
|
else: |
|
await self.emit_event("sub_query_failed", {"sub_query": sub_query}) |
|
|
|
            for idx, (sub_query, future) in enumerate(futures.items(), 1):
                if future.done() and future.exception() is None and future.result().strip():
                    print(f"Sub-query {idx} processed successfully")
                else:
                    child_futures = all_child_futures.get(sub_query, [])
                    if any(cf.done() and cf.exception() is None and cf.result().strip() for cf in child_futures):
                        print(f"Sub-query {idx} processed successfully because of child nodes")
                    else:
                        print(f"Sub-query {idx} failed to process: no content from the node or its children")
|
|
|
print("Graph building complete, processing final tasks...") |
|
await self.emit_event("search_process_completed", { |
|
"depth": depth, |
|
"sub_queries": sub_queries, |
|
"roles": roles |
|
}) |
|
|
|
create_cross_connections() |
|
print("All cross-connections have been created!") |
|
print(f"Adding similarity edges with threshold {threshold}") |
|
|
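        # Final pass: connect every remaining node pair whose combined similarity/relevance clears the threshold.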
|
graph_data = self._get_current_graph_data() |
|
node_map = graph_data["node_map"] |
|
all_node_ids = list(node_map.keys()) |
|
|
|
for i, node1 in enumerate(all_node_ids): |
|
for node2 in all_node_ids[i+1:]: |
|
if not self.edge_exists(node1, node2): |
|
self.add_edge_based_on_similarity_and_relevance(node1, node2, query, threshold) |
|
|
|
print("All similarity edges have been added!") |
|
|
|
async def process_graph( |
|
self, |
|
query: str, |
|
data: str = None, |
|
similarity_threshold: float = 0.8, |
|
relevance_threshold: float = 0.7, |
|
sub_sub_queries: bool = True, |
|
session_id: str = None, |
|
max_tokens_allowed: int = 128000 |
|
): |
|
"""Process a query and manage graph creation/modification.""" |
|
def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]: |
|
if self.current_graph_id is None: |
|
raise Exception("Error: No current graph ID. Cannot check query similarity.") |
|
|
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
similarities = [] |
|
|
|
if not node_map: |
|
return {"should_create_new": True} |
|
|
|
for node_id, idx in node_map.items(): |
|
node_data = graph.get_node_data(idx) |
|
|
|
if not node_data.get("query"): |
|
continue |
|
|
|
similarity = self.calculate_query_similarity(new_query, node_data.get("query")) |
|
if similarity >= similarity_threshold: |
|
similarities.append({ |
|
"node_id": node_id, |
|
"query": node_data.get("query"), |
|
"score": similarity, |
|
"role": node_data.get("role") |
|
}) |
|
|
|
if not similarities: |
|
print(f"No similar queries found above threshold {similarity_threshold}") |
|
return {"should_create_new": True} |
|
|
|
best_match = max(similarities, key=lambda x: x["score"]) |
|
|
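            # The node id prefix encodes granularity: SSQ is a sub-sub-query, SQ a sub-query, anything else the root.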
|
rel_type = "root" |
|
if "SSQ" in best_match["node_id"]: |
|
rel_type = "sub-sub" |
|
|
|
elif "SQ" in best_match["node_id"]: |
|
rel_type = "sub" |
|
|
|
return { |
|
"most_similar_query": best_match["query"], |
|
"similarity_score": best_match["score"], |
|
"relationship_type": rel_type, |
|
"node_id": best_match["node_id"], |
|
"should_create_new": best_match["score"] < similarity_threshold |
|
} |
|
try: |
|
graphs = self.get_graphs() |
|
|
|
if not graphs: |
|
print("No existing graphs found. Creating new graph.") |
|
self.create_new_graph() |
|
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"}) |
|
await self.build_graph( |
|
query=query, |
|
data=data, |
|
threshold=relevance_threshold, |
|
recurse=sub_sub_queries, |
|
session_id=session_id, |
|
max_tokens_allowed=max_tokens_allowed |
|
) |
|
gc.collect() |
|
self.prune_edges() |
|
self.update_pagerank() |
|
self.verify_graph_integrity() |
|
self.verify_graph_consistency() |
|
return |
|
|
|
max_similarity = 0 |
|
most_similar_graph = None |
|
consolidated_graphs = {} |
|
|
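            # Group node lists per graph id, then pick the graph whose nodes best match the new query.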
|
for graph_obj in graphs: |
|
graph_info = graph_obj.get("graph_info") |
|
if not graph_info: |
|
continue |
|
|
|
graph_id = graph_info.get("graph_id") |
|
|
|
if not graph_id: |
|
continue |
|
|
|
if graph_id not in consolidated_graphs: |
|
consolidated_graphs[graph_id] = { |
|
"graph_id": graph_id, |
|
"nodes": [] |
|
} |
|
|
|
if graph_info.get("nodes"): |
|
consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"]) |
|
|
|
for graph_id, graph_data in consolidated_graphs.items(): |
|
nodes = graph_data["nodes"] |
|
|
|
for node in nodes: |
|
if node.get("query"): |
|
similarity = self.calculate_query_similarity(query, node["query"]) |
|
|
|
if node.get("id", "").startswith("SQ"): |
|
asyncio.create_task(self.emit_event("retrieved_sub_query", { |
|
"sub_query": node["query"] |
|
})) |
|
|
|
if similarity > max_similarity: |
|
max_similarity = similarity |
|
most_similar_graph = graph_id |
|
|
|
if max_similarity >= similarity_threshold: |
|
print(f"Found similar query with score {round(max_similarity, 2)}") |
|
self.current_graph_id = most_similar_graph |
|
|
|
if round(max_similarity, 2) == 1.0: |
|
print("Loading and using existing graph") |
|
await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"}) |
|
success = self.load_graph(self.root_node_id) |
|
|
|
if not success: |
|
raise Exception("Failed to load existing graph") |
|
|
|
else: |
|
print("Checking for node-level similarity...") |
|
similarity_info = check_query_similarity( |
|
query, |
|
similarity_threshold |
|
) |
|
|
|
                    if similarity_info.get("relationship_type") in ["sub", "sub-sub"]:
|
print(f"Most Similar Query: {similarity_info['most_similar_query']}") |
|
print("Modifying existing graph structure") |
|
await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"}) |
|
await self.modify_graph( |
|
query, |
|
similarity_info["node_id"], |
|
session_id=session_id |
|
) |
|
gc.collect() |
|
self.prune_edges() |
|
self.update_pagerank() |
|
self.verify_graph_integrity() |
|
self.verify_graph_consistency() |
|
|
|
else: |
|
print(f"Creating new graph for query: {query}") |
|
self.create_new_graph() |
|
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"}) |
|
await self.build_graph( |
|
query=query, |
|
data=data, |
|
threshold=relevance_threshold, |
|
recurse=sub_sub_queries, |
|
session_id=session_id, |
|
max_tokens_allowed=max_tokens_allowed |
|
) |
|
gc.collect() |
|
self.prune_edges() |
|
self.update_pagerank() |
|
self.verify_graph_integrity() |
|
self.verify_graph_consistency() |
|
except (RateLimitError, ResourceExhausted, AnthropicRateLimitError, RetryError): |
|
pass |
|
except Exception as e: |
|
print(f"Error in process_graph: {str(e)}") |
|
raise |
|
|
|
def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8): |
|
"""Add edges based on node similarity and relevance.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
if node1_id not in node_map or node2_id not in node_map: |
|
return |
|
|
|
idx1 = node_map[node1_id] |
|
idx2 = node_map[node2_id] |
|
node1_data = graph.get_node_data(idx1) |
|
node2_data = graph.get_node_data(idx2) |
|
|
|
if not all([node1_data.get("embedding"), node2_data.get("embedding"), node1_data.get("data"), node2_data.get("data")]): |
|
return |
|
|
|
similarity = self.cosine_similarity(node1_data["embedding"], node2_data["embedding"]) |
|
query_relevance1 = self.calculate_relevance(query, node1_data["data"]) |
|
query_relevance2 = self.calculate_relevance(query, node2_data["data"]) |
|
node_relevance = self.calculate_relevance(node1_data["data"], node2_data["data"]) |
|
weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4 |
|
|
|
if weight >= threshold: |
|
self.add_edge(node1_id, node2_id, weight=weight, relationship_type='similarity_and_relevance') |
|
print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}") |
|
|
|
def calculate_relevance(self, data1: str, data2: str) -> float: |
|
"""Calculate relevance between two data strings.""" |
|
try: |
|
if not data1 or not data2: |
|
return 0.0 |
|
|
|
P, R, F1 = self.scorer.score([data1], [data2]) |
|
return F1.mean().item() |
|
except Exception as e: |
|
print(f"Error calculating relevance: {str(e)}") |
|
return 0.0 |
|
|
|
def calculate_query_similarity(self, query1: str, query2: str) -> float: |
|
"""Calculate similarity between two queries.""" |
|
try: |
|
embedding1 = self.model.encode(query1).tolist() |
|
embedding2 = self.model.encode(query2).tolist() |
|
return self.cosine_similarity(embedding1, embedding2) |
|
except Exception as e: |
|
print(f"Error calculating query similarity: {str(e)}") |
|
return 0.0 |
|
|
|
def get_similarities_and_relevance(self, threshold: float = 0.8) -> list: |
|
"""Get similarities and relevance between nodes.""" |
|
try: |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
nodes = [] |
|
|
|
for node_id, idx in node_map.items(): |
|
node_data = graph.get_node_data(idx) |
|
nodes.append({ |
|
"id": node_data.get("id"), |
|
"embedding": node_data.get("embedding"), |
|
"data": node_data.get("data") |
|
}) |
|
|
|
similarities = [] |
|
for i, node1 in enumerate(nodes): |
|
for node2 in nodes[i + 1:]: |
|
similarity = self.cosine_similarity(node1["embedding"], node2["embedding"]) |
|
relevance = self.calculate_relevance(node1["data"], node2["data"]) |
|
weight = (similarity + relevance) / 2 |
|
|
|
if weight >= threshold: |
|
similarities.append({ |
|
'node1': node1["id"], |
|
'node2': node2["id"], |
|
'similarity': similarity, |
|
'relevance': relevance, |
|
'weight': weight |
|
}) |
|
|
|
return similarities |
|
except Exception as e: |
|
print(f"Error getting similarities and relevance: {str(e)}") |
|
return [] |
|
|
|
def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None): |
|
"""Get relationships between nodes with filtering options.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
relationships = {} |
|
|
|
for n_id, idx in node_map.items(): |
|
if n_id == self.root_node_id: |
|
continue |
|
|
|
node_data = graph.get_node_data(idx) |
|
|
|
if node_id and n_id != node_id: |
|
continue |
|
|
|
if role and node_data.get("role") != role: |
|
continue |
|
|
|
in_edges = [] |
|
for u, v, edge_data in graph.in_edges(idx): |
|
source_id = graph.get_node_data(u).get("id") |
|
in_edges.append((source_id, n_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")})) |
|
|
|
out_edges = [] |
|
for u, v, edge_data in graph.out_edges(idx): |
|
target_id = graph.get_node_data(v).get("id") |
|
out_edges.append((n_id, target_id, {"weight": edge_data.get("weight"), "type": edge_data.get("type")})) |
|
|
|
relationships[n_id] = {"in_edges": in_edges, "out_edges": out_edges} |
|
|
|
return relationships |
|
|
|
def find_nodes_by_properties(self, query: str = None, embedding: list = None, |
|
node_data: dict = None, similarity_threshold: float = 0.8) -> list: |
|
"""Find nodes based on properties.""" |
|
try: |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
matching_nodes = [] |
|
|
|
for n_id, idx in node_map.items(): |
|
data = graph.get_node_data(idx) |
|
match_score = 0 |
|
matches = 0 |
|
|
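                # Score each provided criterion independently (query substring, embedding similarity, exact property matches) and average them.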
|
if query and query.lower() in data.get("query", "").lower(): |
|
match_score += 1 |
|
matches += 1 |
|
|
|
if embedding and "embedding" in data: |
|
sim = self.cosine_similarity(embedding, data["embedding"]) |
|
|
|
if sim >= similarity_threshold: |
|
match_score += sim |
|
matches += 1 |
|
|
|
if node_data: |
|
data_matches = sum(1 for k, v in node_data.items() if k in data and data[k] == v) |
|
|
|
if data_matches > 0: |
|
match_score += data_matches / len(node_data) |
|
matches += 1 |
|
|
|
if matches > 0: |
|
matching_nodes.append({ |
|
"node_id": n_id, |
|
"score": match_score / matches, |
|
"data": data |
|
}) |
|
|
|
matching_nodes.sort(key=lambda x: x["score"], reverse=True) |
|
|
|
return matching_nodes |
|
except Exception as e: |
|
print(f"Error finding nodes by properties: {str(e)}") |
|
raise |
|
|
|
def query_graph(self, query: str) -> str: |
|
"""Query the graph for a specific query, collecting data from the entire relevant subgraph.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
target_node_id = None |
|
|
|
for n_id, idx in node_map.items(): |
|
if graph.get_node_data(idx).get("query") == query: |
|
target_node_id = n_id |
|
break |
|
|
|
if not target_node_id: |
|
raise ValueError(f"Query node not found for: {query}") |
|
|
|
datas = [] |
|
start_idx = node_map[target_node_id] |
|
visited = set() |
|
stack = [start_idx] |
|
|
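        # Depth-first traversal from the matched node, collecting non-empty data from the whole reachable subgraph.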
|
while stack: |
|
current = stack.pop() |
|
|
|
if current in visited: |
|
continue |
|
visited.add(current) |
|
current_data = graph.get_node_data(current) |
|
|
|
if current_data.get("data") and current_data.get("data").strip(): |
|
datas.append(current_data.get("data").strip()) |
|
|
|
for neighbor in graph.neighbors(current): |
|
if neighbor not in visited: |
|
stack.append(neighbor) |
|
|
|
if not datas: |
|
print(f"No data found for: {query}") |
|
return "" |
|
|
|
return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)]) |
|
|
|
    def prune_edges(self, max_edges: int = 1000):
        """Prune excess edges while preserving node data."""
        print(f"Pruning edges to maximum {max_edges} edges...")
        graph_data = self._get_current_graph_data()
        graph = graph_data["graph"]
        # weighted_edge_list() yields (source, target, payload) tuples, so edges can be ranked by weight.
        all_edges = list(graph.weighted_edge_list())
        current_edges = len(all_edges)

        if current_edges > max_edges:
            sorted_edges = sorted(all_edges, key=lambda x: x[2].get("weight", 1.0), reverse=True)
            edges_to_keep = set()

            for edge in sorted_edges[:max_edges]:
                edges_to_keep.add((edge[0], edge[1]))

            edges_to_remove = []
            for edge in all_edges:
                if (edge[0], edge[1]) not in edges_to_keep:
                    edges_to_remove.append((edge[0], edge[1]))

            for u, v in edges_to_remove:
                try:
                    graph.remove_edge(u, v)
                except Exception as e:
                    print(f"Error removing edge from {u} to {v}: {e}")

            print(f"Pruned edges. Kept top {max_edges} edges by weight.")
        else:
            print("No pruning required. Current edge count is within limits.")
|
|
|
def update_pagerank(self): |
|
"""Update PageRank values using Rustworkx's pagerank algorithm.""" |
|
if not self.current_graph_id: |
|
print("No current graph selected. Cannot compute PageRank.") |
|
return |
|
|
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
|
|
try: |
|
pr = rx.pagerank(graph, weight_fn=lambda e: e.get("weight", 1.0)) |
|
node_map = graph_data["node_map"] |
|
|
|
for n_id, idx in node_map.items(): |
|
node_data = graph.get_node_data(idx) |
|
node_data["pagerank"] = pr[idx] |
|
|
|
print("PageRank updated successfully") |
|
except Exception as e: |
|
print(f"Error updating PageRank: {str(e)}") |
|
raise |
|
|
|
def display_graph(self): |
|
"""Display the graph using PyVis.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
net = Network(height="530px", width="100%", directed=True, bgcolor="#222222", font_color="white") |
|
net.options = {"physics": {"enabled": False}} |
|
all_nodes = set() |
|
all_edges = [] |
|
|
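        # Walk the edge list once, materializing PyVis nodes on first sight and buffering edges for a second pass.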
|
for (u, v), edge_data in zip(graph.edge_list(), graph.edges()): |
|
source_data = graph.get_node_data(u) |
|
target_data = graph.get_node_data(v) |
|
source_id = source_data.get("id") |
|
target_id = target_data.get("id") |
|
source_tooltip = f"Query: {source_data.get('query', 'N/A')}" |
|
target_tooltip = f"Query: {target_data.get('query', 'N/A')}" |
|
|
|
if source_id not in all_nodes: |
|
net.add_node(source_id, label=source_id, title=source_tooltip, size=20, color="#00cc66") |
|
all_nodes.add(source_id) |
|
|
|
if target_id not in all_nodes: |
|
net.add_node(target_id, label=target_id, title=target_tooltip, size=20, color="#00cc66") |
|
all_nodes.add(target_id) |
|
|
|
edge_type = edge_data.get("type", "N/A") |
|
edge_weight = edge_data.get("weight", "N/A") |
|
edge_tooltip = f"Weight: {edge_weight}" |
|
all_edges.append({ |
|
"from": source_id, |
|
"to": target_id, |
|
"label": edge_type, |
|
"title": edge_tooltip |
|
}) |
|
|
|
for edge in all_edges: |
|
net.add_edge(edge["from"], edge["to"], title=edge["title"], color="#cccccc") |
|
|
|
net.options["layout"] = {"improvedLayout": True} |
|
net.options["interaction"] = {"dragNodes": True} |
|
|
|
net.save_graph("temp_graph.html") |
|
|
|
with open("temp_graph.html", "r", encoding="utf-8") as f: |
|
html_str = f.read() |
|
os.remove("temp_graph.html") |
|
return html_str |
|
|
|
def verify_graph_integrity(self): |
|
"""Verify and fix graph integrity issues.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
orphaned = [] |
|
|
|
for n_id, idx in node_map.items(): |
|
if not graph.in_edges(idx) and not graph.out_edges(idx): |
|
orphaned.append(n_id) |
|
|
|
if orphaned: |
|
print(f"Found orphaned nodes: {orphaned}") |
|
|
|
invalid_edges = [] |
|
for u, v in graph.edge_list(): |
|
target_data = graph.get_node_data(v) |
|
|
|
if target_data.get("graph_id") != self.current_graph_id: |
|
invalid_edges.append((graph.get_node_data(u).get("id"), target_data.get("id"))) |
|
|
|
if invalid_edges: |
|
print(f"Found invalid edges: {invalid_edges}") |
|
edges_to_remove = [] |
|
|
|
for u, v in graph.edge_list(): |
|
if graph.get_node_data(v).get("graph_id") != self.current_graph_id: |
|
edges_to_remove.append((u, v)) |
|
|
|
            for u, v in edges_to_remove:
                try:
                    graph.remove_edge(u, v)
                except Exception as e:
                    # Constructing an Exception without raising it silently swallowed the error; report it instead.
                    print(f"Error removing invalid edge from {u} to {v}: {e}")
|
|
|
print("Graph integrity verified successfully") |
|
|
|
return True |
|
|
|
def verify_graph_consistency(self): |
|
"""Verify consistency of the in-memory graph.""" |
|
graph_data = self._get_current_graph_data() |
|
graph = graph_data["graph"] |
|
node_map = graph_data["node_map"] |
|
|
|
for n_id, idx in node_map.items(): |
|
node_data = graph.get_node_data(idx) |
|
|
|
if node_data.get("id") is None or node_data.get("query") is None: |
|
raise ValueError("Found nodes with missing required properties") |
|
|
|
for edge_data in graph.edges(): |
|
if edge_data.get("type") is None or edge_data.get("weight") is None: |
|
raise ValueError("Found relationships with missing required properties") |
|
|
|
print("Graph consistency verified successfully") |
|
|
|
return True |
|
|
|
    async def close(self):
        """Properly clean up all resources."""
        try:
            if hasattr(self, 'executor'):
                self.executor.shutdown(wait=True)

            # The crawler is stored as `custom_crawler`; checking for `crawler` meant cleanup never ran.
            if hasattr(self, 'custom_crawler'):
                await asyncio.shield(self.custom_crawler.cleanup_expired_sessions())
                await asyncio.shield(self.custom_crawler.cleanup_browser_context(getattr(self, "session_id", None)))
        except Exception as e:
            print(f"Error during cleanup: {e}")
|
|
|
@staticmethod |
|
def cosine_similarity(v1: List[float], v2: List[float]) -> float: |
|
"""Calculate cosine similarity between two vectors.""" |
|
try: |
|
v1_array = np.array(v1) |
|
v2_array = np.array(v2) |
|
return np.dot(v1_array, v2_array) / (np.linalg.norm(v1_array) * np.linalg.norm(v2_array)) |
|
except Exception as e: |
|
print(f"Error calculating cosine similarity: {str(e)}") |
|
return 0.0 |
|
|
|
if __name__ == "__main__": |
|
import os |
|
from dotenv import load_dotenv |
|
from src.reasoning.reasoner import Reasoner |
|
from src.evaluation.evaluator import Evaluator |
|
|
|
load_dotenv() |
|
|
|
graph_search = GraphRAG(num_workers=24) |
|
evaluator = Evaluator() |
|
reasoner = Reasoner() |
|
|
|
async def test_graph_search(): |
|
|
|
queries = [ |
|
"""In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries. |
|
Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including: |
|
1. Carbon pricing mechanisms |
|
2. Subsidies for emerging technologies (such as green hydrogen and battery storage) |
|
3. Cross-border climate finance initiatives |
|
|
|
Highlight the unique challenges faced by both advanced and emerging economies in addressing: |
|
1. Energy poverty |
|
2. Supply chain disruptions |
|
3. Geopolitical tensions (e.g., the Russia-Ukraine conflict) |
|
|
|
Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets. |
|
Note any significant policy gaps, regional collaborations, or innovative best practices. |
|
Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering: |
|
1. Technological breakthroughs |
|
2. Global market trends |
|
3. Potential climate-related disasters |
|
|
|
Present your analysis as a detailed, well-formatted report.""", |
|
"""Analyse the impact of 'hot-money' on the value of Indian Rupee and answer the following questions:- |
|
1. How does it affect the exchange rate? |
|
2. How can it be mitigated/eliminated? |
|
3. Why is it a problem? |
|
4. What are the consequences? |
|
5. What are the alternatives? |
|
- Evaluate the alternatives for pros and cons. |
|
- Evaluate the impact of alternatives on the exchange rate. |
|
- How can they be implemented? |
|
- What are the consequences of each alternative? |
|
- Evaluate the feasibility of the alternatives. |
|
- Pick top 5 alternatives and justify your choices in detail. |
|
6. What are the implications for the Indian economy? Furthermore:- |
|
- Evaluate the impact of the chosen alternatives on the Indian economy.""", |
|
"""Inflation has been an intrinsic past of human civilization since the very beginning. Answer the following questions:- |
|
1. How true is the above statement? |
|
2. What are the causes of inflation? |
|
3. What are the consequences of inflation? |
|
4. Can we completely eliminate inflation?""", |
|
"""Perform a detailed comparison between the ancient Greece and Roman civilizations. |
|
1. What were the key differences between the two civilizations? |
|
- Evaluate the differences in governance, society, and culture |
|
- Evaluate the differences in economy, trade, and military |
|
- Evaluate the differences in technology and infrastructure |
|
2. What were the similarities between the two civilizations? |
|
- Evaluate the similarities in governance, society, and culture |
|
- Evaluate the similarities in economy, trade, and military |
|
- Evaluate the similarities in technology and infrastructure |
|
3. How did these two civilizations influence each other? |
|
- Evaluate the influence of one civilization on the other |
|
4. How did these two civilizations influence the modern world? |
|
5. Was there another civilization that influenced these two? If yes, how?""", |
|
"""Evaluate the long-term effects of colonialism on economic development in Asia:- |
|
1. Include case studies of at least five different countries |
|
2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution |
|
- Evaluate the impact of colonialism on the economy of the country |
|
- Evaluate the impact of colonialism on the economy of the region |
|
- Evaluate the impact of colonialism on the economy of the world |
|
3. How do these effects compare to Africa?""" |
|
] |
|
follow_on_queries = [ |
|
"How is 'hot-money' related to the current economic situation in India?", |
|
"What is inflation?", |
|
"Did ancient Greece and Rome have any impact on modern democracy? If yes, how?", |
|
"Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?" |
|
] |
|
|
|
while True: |
|
print("\n\nEnter query (finish input with an empty line):") |
|
query_lines = [] |
|
|
|
while True: |
|
line = input() |
|
|
|
if line.strip() == "": |
|
break |
|
query_lines.append(line) |
|
|
|
query = "\n".join(query_lines).strip() |
|
|
|
if query.strip().lower() == "exit": |
|
break |
|
print("\n\n" + "="*15 + " Processing Query " + "="*15 + "\n\n") |
|
|
|
await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8) |
|
|
|
answer = graph_search.query_graph(query) |
|
|
|
response = "" |
|
async for chunk in reasoner.answer(query, answer): |
|
response += chunk |
|
print(response, end="", flush=True) |
|
|
|
graph_search.display_graph() |
|
|
|
evaluation = await evaluator.evaluate_response(query, response, [answer]) |
|
print(f"Faithfulness: {evaluation['faithfulness']}") |
|
print(f"Answer Relevancy: {evaluation['answer relevancy']}") |
|
print(f"Context Utilization: {evaluation['contextual recall']}") |
|
|
|
await graph_search.close() |
|
|
|
asyncio.run(test_graph_search()) |