|
import os
import gc
import time
import uuid
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import List, Dict, Any

import numpy as np
import torch
from bert_score.scorer import BERTScorer
from neo4j import GraphDatabase
from pyvis.network import Network
from sentence_transformers import SentenceTransformer

from src.crawl.crawler import Crawler, CustomCrawler
from src.query_processing.late_chunking.late_chunker import LateChunker
from src.query_processing.query_processor import QueryProcessor
from src.reasoning.reasoner import Reasoner
from src.search.search_engine import SearchEngine
from src.utils.api_key_manager import APIKeyManager
|
|
|
class Neo4jGraphRAG: |
|
def __init__(self, num_workers: int = 1): |
|
"""Initialize Neo4j connection and required components.""" |
|
|
|
self.neo4j_uri = os.getenv("NEO4J_URI") |
|
self.neo4j_user = os.getenv("NEO4J_USER") |
|
self.neo4j_password = os.getenv("NEO4J_PASSWORD") |
|
self.driver = GraphDatabase.driver( |
|
self.neo4j_uri, |
|
auth=(self.neo4j_user, self.neo4j_password) |
|
) |
|
|
|
|
|
self.num_workers = num_workers |
|
self.search_engine = SearchEngine() |
|
self.query_processor = QueryProcessor() |
|
self.reasoner = Reasoner() |
|
|
|
self.custom_crawler = CustomCrawler(max_concurrent_requests=1000) |
|
self.chunking = LateChunker() |
|
self.llm = APIKeyManager().get_llm() |
|
|
|
|
|
self.model = SentenceTransformer( |
|
"dunzhang/stella_en_400M_v5", |
|
trust_remote_code=True, |
|
device="cuda" if torch.cuda.is_available() else "cpu" |
|
) |
|
self.scorer = BERTScorer( |
|
model_type="roberta-base", |
|
lang="en", |
|
rescale_with_baseline=True, |
|
device= "cuda" if torch.cuda.is_available() else "cpu" |
|
) |
|
|
|
|
|
self.root_node_id = "QR" |
|
self.node_counter = 0 |
|
self.sub_node_counter = 0 |
|
self.cross_connections = set() |
|
|
|
|
|
self.current_graph_id = None |
|
|
|
|
|
self.executor = ThreadPoolExecutor(max_workers=self.num_workers) |
|
|
|
|
|
self.on_event_callback = None |
|
|
|
def set_on_event_callback(self, callback): |
|
"""Register a single callback to be triggered for various event types.""" |
|
self.on_event_callback = callback |
|
|
|
async def emit_event(self, event_type: str, data: dict): |
|
"""Helper method to safely emit an event if a callback is registered.""" |
|
if self.on_event_callback: |
|
|
|
if asyncio.iscoroutinefunction(self.on_event_callback): |
|
|
|
return await self.on_event_callback(event_type, data) |
|
else: |
|
return self.on_event_callback(event_type, data) |
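    # Example (hypothetical callback and `rag = Neo4jGraphRAG()` instance):
    #     async def on_event(event_type: str, data: dict):
    #         print(f"[{event_type}] node={data.get('node_id')}")
    #     rag.set_on_event_callback(on_event)
    # emit_event then awaits the coroutine or calls the plain function as needed.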
|
|
|
    @contextmanager
    def transaction(self, max_retries: int = 1):
        """Synchronous context manager for Neo4j transactions.

        A generator-based context manager can only yield once, so the body of
        the ``with`` block cannot be re-run on failure; the transaction is
        rolled back and the exception re-raised for the caller to handle
        (``max_retries`` is kept for API compatibility).
        """
        session = self.driver.session()
        try:
            tx = session.begin_transaction()
            try:
                yield tx
                # Commit unless the body already closed the transaction itself
                if not tx.closed():
                    tx.commit()
            except Exception as e:
                if not tx.closed():
                    tx.rollback()
                print(f"Transaction failed: {str(e)}")
                raise
        finally:
            session.close()
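    # Usage sketch: every write in this class funnels through this manager, e.g.
    #     with self.transaction() as tx:
    #         tx.run("MATCH (n:Node) RETURN count(n)")
    # Commit and rollback are handled by the context manager, not the caller.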
|
|
|
def initialize_schema(self): |
|
"""Check and initialize database schema.""" |
|
constraint_node_id_per_graph = None |
|
        index_node_graph_id = None
|
index_node_role = None |
|
constraint_graph_id = None |
|
index_graph_created = None |
|
constraint_node_graph = None |
|
|
|
try: |
|
with self.transaction() as tx: |
|
|
|
constraint_node_id_per_graph = tx.run(""" |
|
SHOW CONSTRAINTS |
|
WHERE name = 'constraint_node_id_per_graph' |
|
""").data() |
|
|
|
index_node_role = tx.run(""" |
|
SHOW INDEXES |
|
WHERE name = 'index_node_role' |
|
""").data() |
|
|
|
index_node_graph_id = tx.run(""" |
|
SHOW INDEXES |
|
WHERE name = 'index_node_graph_id' |
|
""").data() |
|
|
|
constraint_graph_id = tx.run(""" |
|
SHOW CONSTRAINTS |
|
WHERE name = 'constraint_graph_id' |
|
""").data() |
|
|
|
index_graph_created = tx.run(""" |
|
SHOW INDEXES |
|
WHERE name = 'index_graph_created' |
|
""").data() |
|
|
|
constraint_node_graph = tx.run(""" |
|
SHOW CONSTRAINTS |
|
WHERE name = 'constraint_node_graph' |
|
""").data() |
|
|
|
if constraint_node_id_per_graph and index_node_role and \ |
|
index_node_graph_id and constraint_graph_id and index_graph_created and constraint_node_graph: |
|
print("Database schema already initialized") |
|
return |
|
|
|
print("Initializing database schema...") |
|
|
|
|
|
if not constraint_node_id_per_graph: |
|
tx.run(""" |
|
CREATE CONSTRAINT constraint_node_id_per_graph IF NOT EXISTS |
|
FOR (n:Node) |
|
REQUIRE (n.id, n.graph_id) IS UNIQUE |
|
""") |
|
|
|
if not index_node_role: |
|
tx.run(""" |
|
CREATE INDEX index_node_role IF NOT EXISTS FOR (n:Node) |
|
ON (n.role) |
|
""") |
|
|
|
if not index_node_graph_id: |
|
tx.run(""" |
|
CREATE INDEX index_node_graph_id IF NOT EXISTS FOR (n:Node) |
|
ON (n.graph_id) |
|
""") |
|
|
|
|
|
if not constraint_graph_id: |
|
tx.run(""" |
|
CREATE CONSTRAINT constraint_graph_id IF NOT EXISTS |
|
FOR (g:Graph) |
|
REQUIRE g.id IS UNIQUE |
|
""") |
|
|
|
if not index_graph_created: |
|
tx.run(""" |
|
CREATE INDEX index_graph_created IF NOT EXISTS FOR (g:Graph) |
|
ON (g.created) |
|
""") |
|
|
|
if not constraint_node_graph: |
|
tx.run(""" |
|
CREATE CONSTRAINT constraint_node_graph IF NOT EXISTS |
|
FOR (n:Node) |
|
REQUIRE n.graph_id IS NOT NULL |
|
""") |
|
|
|
print("Database schema initialization complete") |
|
|
|
except Exception as e: |
|
print(f"Error ensuring schema exists: {str(e)}") |
|
raise |
|
|
|
def add_node(self, node_id: str, query: str, data: str = "", role: str = None): |
|
"""Add a node to the current graph.""" |
|
if self.current_graph_id is None: |
|
raise Exception("Error: No current graph selected") |
|
|
|
try: |
|
with self.transaction() as tx: |
|
|
|
embedding = self.model.encode(query).tolist() |
|
|
|
|
|
result = tx.run( |
|
""" |
|
MERGE (n:Node {id: $node_id, graph_id: $graph_id}) |
|
SET n.query = $node_query, |
|
n.embedding = $embedding, |
|
n.data = $data, |
|
n.role = $role |
|
""", |
|
node_id=node_id, |
|
graph_id=self.current_graph_id, |
|
node_query=query, |
|
embedding=embedding, |
|
data=data, |
|
role=role |
|
) |
|
print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'") |
|
|
|
except Exception as e: |
|
print(f"Error adding node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}': {str(e)}") |
|
raise |
|
|
|
def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None): |
|
"""Add an edge between two nodes in a way that preserves a DAG structure in the graph""" |
|
if self.current_graph_id is None: |
|
raise Exception("Error: No current graph selected") |
|
|
|
|
|
if node1 == node2: |
|
print(f"Cannot add edge to the same node {node1}!") |
|
return |
|
|
|
try: |
|
with self.transaction() as tx: |
|
|
|
check_path = tx.run( |
|
""" |
|
MATCH (start:Node {id: $node2, graph_id: $graph_id}) |
|
MATCH (end:Node {id: $node1, graph_id: $graph_id}) |
|
// If there's any path of length >= 0 from 'start' to 'end', |
|
// then creating (end)->(start) would introduce a cycle. |
|
WHERE (start)-[:RELATION*0..]->(end) |
|
RETURN COUNT(start) AS pathExists |
|
""", |
|
node1=node1, |
|
node2=node2, |
|
graph_id=self.current_graph_id |
|
) |
|
path_count = check_path.single()["pathExists"] |
|
|
|
if path_count > 0: |
|
print(f"An edge between {node1} -> {node2} already exists!") |
|
return |
|
|
|
|
|
tx.run( |
|
""" |
|
MATCH (a:Node {id: $node1, graph_id: $graph_id}) |
|
MATCH (b:Node {id: $node2, graph_id: $graph_id}) |
|
MERGE (a)-[r:RELATION {type: $rel_type}]->(b) |
|
SET r.weight = $weight |
|
""", |
|
node1=node1, |
|
node2=node2, |
|
graph_id=self.current_graph_id, |
|
rel_type=relationship_type, |
|
weight=weight |
|
) |
|
|
|
print( |
|
f"Added edge between '{node1}' and '{node2}' in graph " |
|
f"'{self.current_graph_id}' (type='{relationship_type}', weight={weight})" |
|
) |
|
|
|
except Exception as e: |
|
print(f"Error adding edge between '{node1}' and '{node2}': {str(e)}") |
|
raise |
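    # DAG sketch (hypothetical IDs): if a path SQ1 -> SSQ1 already exists,
    # add_edge("SSQ1", "SQ1") is skipped by the reachability check above,
    # so hierarchical and logical edges can never form a cycle.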
|
|
|
def edge_exists(self, node1: str, node2: str) -> bool: |
|
"""Check if an edge exists between two nodes.""" |
|
try: |
|
with self.transaction() as tx: |
|
                result = tx.run(
                    """
                    MATCH (a:Node {id: $node1, graph_id: $graph_id})-[r:RELATION]-(b:Node {id: $node2, graph_id: $graph_id})
                    RETURN COUNT(r) as count
                    """,
                    node1=node1,
                    node2=node2,
                    graph_id=self.current_graph_id
                )
                return result.single()["count"] > 0
|
|
|
except Exception as e: |
|
print(f"Error checking edge existence between {node1} and {node2}: {str(e)}") |
|
raise |
|
|
|
def graph_exists(self) -> bool: |
|
"""Check if a graph exists in Neo4j.""" |
|
try: |
|
with self.transaction() as tx: |
|
result = tx.run(""" |
|
MATCH (n:Node) |
|
RETURN count(n) > 0 as has_nodes |
|
""") |
|
return result.single()["has_nodes"] |
|
except Exception as e: |
|
print(f"Error checking graph existence: {str(e)}") |
|
raise |
|
|
|
def get_graphs(self) -> list: |
|
"""Get detailed information about all existing graphs and their nodes.""" |
|
try: |
|
with self.transaction() as tx: |
|
result = tx.run( |
|
""" |
|
MATCH (g:Graph) |
|
OPTIONAL MATCH (n:Node {graph_id: g.id})-[r:RELATION]->(:Node) |
|
WITH g, collect(DISTINCT n) AS nodes, collect(DISTINCT r) AS rels |
|
RETURN { |
|
graph_id: g.id, |
|
created: g.created, |
|
updated: g.updated, |
|
node_count: size(nodes), |
|
edge_count: size(rels), |
|
nodes: [node IN nodes | { |
|
id: node.id, |
|
query: node.query, |
|
data: node.data, |
|
role: node.role, |
|
pagerank: node.pagerank |
|
}] |
|
} as graph_info |
|
ORDER BY g.created DESC |
|
""" |
|
) |
|
return list(result) |
|
except Exception as e: |
|
print(f"Error getting graphs: {str(e)}") |
|
raise |
|
|
|
def select_graph(self, graph_id: str) -> bool: |
|
"""Select a specific graph as the current working graph.""" |
|
try: |
|
with self.transaction() as tx: |
|
result = tx.run(""" |
|
MATCH (g:Graph {id: $graph_id}) |
|
RETURN g |
|
""", graph_id=graph_id) |
|
|
|
if result.single(): |
|
self.current_graph_id = graph_id |
|
return True |
|
return False |
|
|
|
except Exception as e: |
|
print(f"Error selecting graph: {str(e)}") |
|
raise |
|
|
|
def create_new_graph(self) -> str: |
|
"""Create a new graph instance and its ID.""" |
|
try: |
|
with self.transaction() as tx: |
|
graph_id = str(uuid.uuid4()) |
|
tx.run(""" |
|
CREATE (g:Graph { |
|
id: $graph_id, |
|
created: datetime(), |
|
updated: datetime() |
|
}) |
|
""", graph_id=graph_id) |
|
|
|
                self.current_graph_id = graph_id
                return graph_id
|
|
|
except Exception as e: |
|
print(f"Error creating new graph: {str(e)}") |
|
raise |
|
|
|
def load_graph(self, node_id: str) -> bool: |
|
"""Load an existing graph structure from Neo4j based on node ID.""" |
|
|
|
def extract_number(node_id: str) -> int: |
|
try: |
|
|
|
num_str = ''.join(filter(str.isdigit, node_id)) |
|
return int(num_str) if num_str else 0 |
|
except ValueError: |
|
print(f"Warning: Could not extract number from node ID: {node_id}") |
|
return 0 |
|
|
|
try: |
|
with self.driver.session() as session: |
|
|
|
tx = session.begin_transaction() |
|
|
|
try: |
|
|
|
result = tx.run(""" |
|
MATCH path = (n:Node)-[r:RELATION*0..]->(m:Node) |
|
WHERE n.id = $node_id |
|
RETURN DISTINCT n, r, m, |
|
length(path) as depth, |
|
[rel in r | type(rel)] as rel_types, |
|
[rel in r | rel.weight] as weights |
|
""", node_id=node_id) |
|
|
|
|
|
self.node_counter = 0 |
|
self.sub_node_counter = 0 |
|
self.cross_connections.clear() |
|
|
|
|
|
processed_nodes = set() |
|
|
|
|
|
                    for record in result:
                        # Track the highest SQ/SSQ counter seen on either endpoint
                        # of each path (check SSQ before SQ, since "SSQ" ids also
                        # contain the substring "SQ")
                        for path_node in (record["n"], record["m"]):
                            current_id = path_node["id"]
                            if current_id in processed_nodes:
                                continue
                            if current_id.startswith("SSQ"):
                                self.sub_node_counter = max(self.sub_node_counter, extract_number(current_id))
                            elif current_id.startswith("SQ"):
                                self.node_counter = max(self.node_counter, extract_number(current_id))
                            processed_nodes.add(current_id)
|
|
|
|
|
self.node_counter += 1 |
|
self.sub_node_counter += 1 |
|
|
|
|
|
result = tx.run(""" |
|
MATCH (n:Node)-[r:RELATION]->(m:Node) |
|
WHERE r.type = 'logical' |
|
RETURN n.id as source, m.id as target |
|
""") |
|
|
|
for record in result: |
|
connection = tuple(sorted([record["source"], record["target"]])) |
|
self.cross_connections.add(connection) |
|
|
|
tx.commit() |
|
print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}") |
|
return True |
|
|
|
except Exception as e: |
|
tx.rollback() |
|
print(f"Transaction error while loading graph: {str(e)}") |
|
return False |
|
|
|
except Exception as e: |
|
print(f"Error loading graph: {str(e)}") |
|
return False |
|
|
|
async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None): |
|
"""Modify an existing graph structure by integrating a new query.""" |
|
|
|
|
|
async def add_as_sibling(node_id: str, query: str): |
|
with self.transaction() as tx: |
|
result = tx.run(""" |
|
MATCH (n:Node)<-[r:RELATION]-(parent:Node) |
|
WHERE n.id = $node_id |
|
RETURN parent.id as parent_id, |
|
parent.query as parent_query, |
|
r.type as rel_type |
|
""", node_id=node_id) |
|
|
|
parent_data = result.single() |
|
if not parent_data: |
|
raise ValueError(f"No parent found for node {node_id}") |
|
|
|
if "SQ" in node_id: |
|
self.node_counter += 1 |
|
new_node_id = f"SQ{self.node_counter}" |
|
else: |
|
self.sub_node_counter += 1 |
|
new_node_id = f"SSQ{self.sub_node_counter}" |
|
|
|
self.add_node( |
|
node_id=new_node_id, |
|
query=query, |
|
role="independent" |
|
) |
|
self.add_edge( |
|
parent_data["parent_id"], |
|
new_node_id, |
|
relationship_type=parent_data["rel_type"] |
|
) |
|
|
|
return new_node_id |
|
|
|
|
|
async def add_as_child(node_id: str, query: str): |
|
if "SQ" in node_id: |
|
self.sub_node_counter += 1 |
|
new_node_id = f"SSQ{self.sub_node_counter}" |
|
else: |
|
self.node_counter += 1 |
|
new_node_id = f"SQ{self.node_counter}" |
|
|
|
self.add_node( |
|
node_id=new_node_id, |
|
query=query, |
|
role="dependent" |
|
) |
|
self.add_edge( |
|
node_id, |
|
new_node_id, |
|
relationship_type="logical" |
|
) |
|
|
|
return new_node_id |
|
|
|
|
|
def collect_graph_context() -> list: |
|
try: |
|
with self.transaction() as tx: |
|
|
|
result = tx.run(""" |
|
MATCH (n:Node) |
|
WHERE n.id <> $root_id AND n.graph_id = $graph_id |
|
WITH n |
|
ORDER BY |
|
CASE |
|
WHEN n.id STARTS WITH 'SQ' THEN 1 |
|
WHEN n.id STARTS WITH 'SSQ' THEN 2 |
|
ELSE 3 |
|
END, |
|
n.id |
|
RETURN COLLECT({ |
|
id: n.id, |
|
query: n.query, |
|
role: n.role |
|
}) as nodes |
|
""", root_id=self.root_node_id, graph_id=self.current_graph_id) |
|
|
|
nodes = result.single()["nodes"] |
|
if not nodes: |
|
return [] |
|
|
|
|
|
level_queries = {} |
|
current_sq = None |
|
|
|
for node in nodes: |
|
node_id = node["id"] |
|
if node_id.startswith("SQ"): |
|
current_sq = node_id |
|
if current_sq not in level_queries: |
|
level_queries[current_sq] = { |
|
"originalquery": node["query"], |
|
"subqueries": [] |
|
} |
|
|
|
|
|
level_queries[current_sq]["subqueries"].append({ |
|
"subquery": node["query"], |
|
"role": node["role"], |
|
"dependson": [] |
|
}) |
|
|
|
elif node_id.startswith("SSQ") and current_sq: |
|
level_queries[current_sq]["subqueries"].append({ |
|
"subquery": node["query"], |
|
"role": node["role"], |
|
"dependson": [] |
|
}) |
|
|
|
|
|
for sq_id, query_data in level_queries.items(): |
|
for i, sub_query in enumerate(query_data["subqueries"]): |
|
|
|
deps = tx.run(""" |
|
MATCH (n:Node {query: $node_query})-[r:RELATION {type: 'logical'}]->(m:Node) |
|
WHERE n.graph_id = $graph_id |
|
RETURN COLLECT(m.query) as dependencies |
|
""", node_query=sub_query["subquery"], graph_id=self.current_graph_id) |
|
|
|
dep_queries = deps.single()["dependencies"] |
|
if dep_queries: |
|
|
|
curr_deps = [] |
|
prev_deps = [] |
|
for dep_query in dep_queries: |
|
|
|
curr_idx = next( |
|
(idx for idx, sq in enumerate(query_data["subqueries"]) |
|
if sq["subquery"] == dep_query), |
|
None |
|
) |
|
if curr_idx is not None: |
|
curr_deps.append(curr_idx) |
|
else: |
|
|
|
for prev_idx, prev_data in enumerate(level_queries.values()): |
|
if dep_query in [sq["subquery"] for sq in prev_data["subqueries"]]: |
|
prev_deps.append(prev_idx) |
|
break |
|
|
|
query_data["subqueries"][i]["dependson"] = [prev_deps, curr_deps] |
|
|
|
|
|
return list(level_queries.values()) |
|
|
|
except Exception as e: |
|
print(f"Error collecting graph context: {str(e)}") |
|
raise |
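        # Context shape sketch (illustrative): the returned list looks like
        #     [{"originalquery": "...",
        #       "subqueries": [{"subquery": "...", "role": "independent",
        #                       "dependson": [[], []]}, ...]},
        #      ...]
        # where "dependson" holds [previous-level indices, current-level indices].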
|
|
|
try: |
|
|
|
with self.transaction() as tx: |
|
result = tx.run(""" |
|
MATCH (n:Node {id: $node_id}) |
|
RETURN n.role as role, |
|
n.query as query, |
|
EXISTS((n)<-[:RELATION]-()) as has_parent |
|
""", node_id=similar_node_id) |
|
|
|
node_data = result.single() |
|
if not node_data: |
|
raise Exception(f"Node {similar_node_id} not found") |
|
|
|
|
|
context = collect_graph_context() |
|
|
|
|
|
if node_data["role"] == "independent": |
|
|
|
if node_data["has_parent"]: |
|
new_node_id = await add_as_sibling(similar_node_id, new_query) |
|
else: |
|
new_node_id = await add_as_child(similar_node_id, new_query) |
|
else: |
|
|
|
new_node_id = await add_as_child(similar_node_id, new_query) |
|
|
|
|
|
await self.build_graph( |
|
query=new_query, |
|
parent_node_id=new_node_id, |
|
depth=1 if "SQ" in new_node_id else 2, |
|
context=context, |
|
session_id=session_id |
|
) |
|
|
|
except Exception as e: |
|
print(f"Error modifying graph: {str(e)}") |
|
raise |
|
|
|
async def build_graph(self, query: str, data: str = None, parent_node_id: str = None, |
|
depth: int = 0, threshold: float = 0.8, recurse: bool = True, |
|
context: list = None, session_id: str = None, max_tokens_allowed: int = 128000): |
|
"""Build a new graph structure in Neo4j.""" |
|
|
|
async def process_node(self, node_id: str, sub_query: str, |
|
session_id: str, future: asyncio.Future, |
|
depth=depth, max_tokens_allowed=max_tokens_allowed): |
|
"""Process a node asynchronously.""" |
|
try: |
|
|
|
optimized_query = await self.search_engine.generate_optimized_query(sub_query) |
|
|
|
|
|
results = await self.search_engine.search( |
|
query=optimized_query, |
|
num_results=10, |
|
exclude_filetypes=["pdf"] |
|
) |
|
|
|
await self.emit_event("search_results_fetched", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"optimized_query": optimized_query, |
|
"search_results": results |
|
}) |
|
|
|
|
|
filtered_urls = await self.search_engine.filter_urls( |
|
sub_query, |
|
"ultra", |
|
results |
|
) |
|
|
|
await self.emit_event("search_results_filtered", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"filtered_urls": filtered_urls |
|
}) |
|
|
|
|
|
            urls = [result.get('link') for result in filtered_urls if result.get('link')]
|
|
|
|
|
search_contents = await self.custom_crawler.fetch_page_contents( |
|
urls, |
|
sub_query, |
|
session_id=session_id, |
|
max_attempts=1, |
|
timeout=30 |
|
) |
|
|
|
await self.emit_event("search_contents_fetched", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"contents": search_contents |
|
}) |
|
|
|
|
|
contents = "" |
|
for k, content in enumerate(search_contents, 1): |
|
if isinstance(content, Exception): |
|
print(f"Error fetching content: {content}") |
|
elif content: |
|
contents += f"Document {k}:\n{content}\n\n" |
|
|
|
if len(contents.strip()) > 0: |
|
if depth == 0: |
|
|
|
await self.emit_event("sub_query_processed", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"contents": contents |
|
}) |
|
|
|
|
|
                    token_count = self.llm.get_num_tokens(contents)
                    if token_count > max_tokens_allowed:
                        contents = await self.chunking.chunker(
                            text=contents,
                            query=sub_query,
                            max_tokens=max_tokens_allowed
                        )
                        print(f"Number of tokens before chunking: {token_count}")
                        print(f"Number of tokens after chunking: {self.llm.get_num_tokens(contents)}")
|
else: |
|
if depth == 0: |
|
|
|
await self.emit_event("sub_query_failed", { |
|
"node_id": node_id, |
|
"sub_query": sub_query, |
|
"contents": contents |
|
}) |
|
|
|
|
|
with self.transaction() as tx: |
|
tx.run( |
|
""" |
|
MATCH (n:Node {id: $node_id}) |
|
SET n.data = $data |
|
""", |
|
node_id=node_id, |
|
data=contents |
|
) |
|
|
|
|
|
future.set_result(contents) |
|
|
|
except Exception as e: |
|
print(f"Error processing node {node_id}: {str(e)}") |
|
if depth == 0: |
|
await self.emit_event("sub_query_failed", { |
|
"node_id": node_id, |
|
"sub_query": sub_query |
|
}) |
|
future.set_exception(e) |
|
raise |
|
|
|
async def process_dependent_node(self, node_id: str, sub_query: str, depth, dep_futures: list, future): |
|
"""Process a dependent node asynchronously.""" |
|
try: |
|
loop = asyncio.get_running_loop() |
|
|
|
|
|
dep_data = [await f for f in dep_futures] |
|
|
|
|
|
modified_query = await self.query_processor.modify_query( |
|
sub_query, |
|
dep_data |
|
) |
|
|
|
|
|
embedding = await loop.run_in_executor( |
|
self.executor, |
|
self.model.encode, |
|
modified_query |
|
) |
|
|
|
|
|
with self.transaction() as tx: |
|
tx.run( |
|
""" |
|
MATCH (n:Node {id: $node_id}) |
|
SET n.query = $modified_query, |
|
n.embedding = $embedding |
|
""", |
|
node_id=node_id, |
|
modified_query=modified_query, |
|
embedding=embedding.tolist() |
|
) |
|
|
|
|
|
try: |
|
if not future.done(): |
|
await process_node( |
|
self, node_id, modified_query, session_id, future, depth, max_tokens_allowed |
|
) |
|
except Exception as e: |
|
if depth == 0: |
|
await self.emit_event("sub_query_failed", { |
|
"node_id": node_id, |
|
"sub_query": sub_query |
|
}) |
|
if not future.done(): |
|
future.set_exception(e) |
|
raise |
|
|
|
except Exception as e: |
|
print(f"Error processing dependent node {node_id}: {str(e)}") |
|
if depth == 0: |
|
await self.emit_event("sub_query_failed", { |
|
"node_id": node_id, |
|
"sub_query": sub_query |
|
}) |
|
if not future.done(): |
|
future.set_exception(e) |
|
raise |
|
|
|
def create_cross_connections(self, node_id=None, depth=None, role=None): |
|
"""Create cross connections based on dependencies.""" |
|
try: |
|
|
|
relationships = self.get_node_relationships( |
|
node_id=node_id, |
|
depth=depth, |
|
role=role, |
|
relationship_type='logical' |
|
) |
|
|
|
for current_node_id, edges in relationships.items(): |
|
|
|
with self.transaction() as tx: |
|
result = tx.run( |
|
"MATCH (n:Node {id: $node_id}) RETURN n.role as role", |
|
node_id=current_node_id |
|
) |
|
node_data = result.single() |
|
if not node_data or not node_data["role"]: |
|
continue |
|
|
|
node_role = node_data["role"].lower() |
|
|
|
|
|
if node_role == 'dependent': |
|
|
|
for source_id, target_id, edge_data in edges['in_edges']: |
|
if not source_id or source_id == self.root_node_id: |
|
continue |
|
|
|
|
|
connection = tuple(sorted([current_node_id, source_id])) |
|
|
|
|
|
if connection not in self.cross_connections: |
|
if not self.edge_exists(source_id, current_node_id): |
|
print(f"Adding cross-connection edge between {source_id} and {current_node_id}") |
|
self.add_edge( |
|
source_id, |
|
current_node_id, |
|
weight=edge_data.get('weight', 1.0), |
|
relationship_type='logical' |
|
) |
|
self.cross_connections.add(connection) |
|
|
|
|
|
for source_id, target_id, edge_data in edges['out_edges']: |
|
if not target_id or target_id == self.root_node_id: |
|
continue |
|
|
|
|
|
connection = tuple(sorted([current_node_id, target_id])) |
|
|
|
|
|
if connection not in self.cross_connections: |
|
if not self.edge_exists(current_node_id, target_id): |
|
print(f"Adding cross-connection edge between {current_node_id} and {target_id}") |
|
self.add_edge( |
|
current_node_id, |
|
target_id, |
|
weight=edge_data.get('weight', 1.0), |
|
relationship_type='logical' |
|
) |
|
self.cross_connections.add(connection) |
|
|
|
except Exception as e: |
|
print(f"Error creating cross connections: {str(e)}") |
|
raise |
|
|
|
|
|
|
|
if depth > 1: |
|
return |
|
|
|
|
|
if context is None: |
|
context = [] |
|
|
|
|
|
node_data_futures = {} |
|
|
|
if parent_node_id is None: |
|
|
|
self.add_node(self.root_node_id, query, data) |
|
parent_node_id = self.root_node_id |
|
|
|
|
|
intent = await self.query_processor.get_query_intent(query) |
|
|
|
if depth == 0: |
|
|
|
response_data, sub_queries, roles, dependencies = \ |
|
await self.query_processor.decompose_query_with_dependencies(query, intent) |
|
else: |
|
|
|
response_data, sub_queries, roles, dependencies = \ |
|
await self.query_processor.decompose_query_with_dependencies( |
|
query, |
|
intent, |
|
context |
|
) |
|
|
|
|
|
if response_data: |
|
context.append(response_data) |
|
|
|
|
|
if len(sub_queries) > 1 and sub_queries[0] != query: |
|
sub_query_ids = [] |
|
pre_req_nodes = {} |
|
|
|
|
|
for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)): |
|
|
|
|
|
if depth == 0: |
|
await self.emit_event( |
|
"sub_query_created", |
|
{ |
|
"depth": depth, |
|
"sub_query": sub_query, |
|
"role": role, |
|
"dependency": dependency, |
|
"parent_node_id": parent_node_id, |
|
} |
|
) |
|
|
|
|
|
if depth == 0: |
|
self.node_counter += 1 |
|
sub_node_id = f"SQ{self.node_counter}" |
|
else: |
|
self.sub_node_counter += 1 |
|
sub_node_id = f"SSQ{self.sub_node_counter}" |
|
|
|
|
|
sub_query_ids.append(sub_node_id) |
|
|
|
|
|
self.add_node(node_id=sub_node_id, query=sub_query, role=role) |
|
|
|
|
|
future = asyncio.Future() |
|
node_data_futures[sub_node_id] = future |
|
|
|
if role.lower() in ('pre-requisite', 'prerequisite'): |
|
pre_req_nodes[idx] = sub_node_id |
|
|
|
|
|
if role.lower() in ('pre-requisite', 'prerequisite', 'independent'): |
|
|
|
self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical') |
|
elif role.lower() == 'dependent': |
|
                    if isinstance(dependency, list) and len(dependency) == 2 and all(isinstance(d, list) for d in dependency):
|
print(f"Dependency: {dependency}") |
|
|
|
prev_deps, current_deps = dependency |
|
|
|
|
|
if context and prev_deps not in [None, []]: |
|
for dep_idx in prev_deps: |
|
if dep_idx is not None: |
|
|
|
for context_data in context: |
|
if context_data and 'subqueries' in context_data: |
|
if dep_idx < len(context_data['subqueries']): |
|
|
|
sub_query_data = context_data['subqueries'][dep_idx] |
|
if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data: |
|
dep_query = sub_query_data['subquery'] |
|
|
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
|
|
if matching_nodes not in [None, []]: |
|
dep_node_id = matching_nodes[0].get('node_id') |
|
score = matching_nodes[0].get('score', 0) |
|
if score >= 0.9: |
|
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical') |
|
|
|
|
|
if current_deps not in [None, []]: |
|
for dep_idx in current_deps: |
|
if dep_idx < len(sub_queries): |
|
dep_node_id = sub_query_ids[dep_idx] |
|
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical') |
|
else: |
|
|
|
raise ValueError(f"Invalid dependency index: {dep_idx}") |
|
elif len(dependency) > 0: |
|
for dep_idx in dependency: |
|
if dep_idx < len(sub_queries): |
|
|
|
dep_node_id = sub_query_ids[dep_idx] |
|
|
|
self.add_edge(dep_node_id, sub_node_id, relationship_type='logical') |
|
else: |
|
raise ValueError(f"Invalid dependency index: {dep_idx}") |
|
else: |
|
|
|
raise ValueError(f"Invalid dependency: {dependency}") |
|
else: |
|
|
|
raise ValueError(f"Unexpected role: {role}") |
|
|
|
|
|
tasks = [] |
|
|
|
|
|
for idx in range(len(sub_queries)): |
|
node_id = sub_query_ids[idx] |
|
future = node_data_futures[node_id] |
|
if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'): |
|
tasks.append(process_node( |
|
self, node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed |
|
)) |
|
|
|
|
|
for idx in range(len(sub_queries)): |
|
node_id = sub_query_ids[idx] |
|
future = node_data_futures[node_id] |
|
|
|
if roles[idx].lower() == 'dependent': |
|
dep_futures = [] |
|
if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2: |
|
prev_deps, current_deps = dependencies[idx] |
|
|
|
|
|
if context and prev_deps not in [None, []]: |
|
for context_idx, context_data in enumerate(context): |
|
|
|
if isinstance(prev_deps, list) and context_idx < len(prev_deps): |
|
context_dep = prev_deps[context_idx] |
|
if context_dep is not None: |
|
if context_data and 'subqueries' in context_data: |
|
if context_dep < len(context_data['subqueries']): |
|
sub_query_data = context_data['subqueries'][context_dep] |
|
if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data: |
|
dep_query = sub_query_data['subquery'] |
|
|
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes not in [None, []]: |
|
|
|
dep_node_id = matching_nodes[0].get('node_id', None) |
|
score = float(matching_nodes[0].get('score', 0)) |
|
if score == 1.0 and dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
|
|
|
|
elif isinstance(prev_deps, int): |
|
if prev_deps < len(context_data['subqueries']): |
|
sub_query_data = context_data['subqueries'][prev_deps] |
|
if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data: |
|
dep_query = sub_query_data['subquery'] |
|
|
|
matching_nodes = self.find_nodes_by_properties(query=dep_query) |
|
if matching_nodes not in [None, []]: |
|
|
|
dep_node_id = matching_nodes[0].get('node_id', None) |
|
score = matching_nodes[0].get('score', 0) |
|
if score == 1.0 and dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
|
|
|
|
if current_deps not in [None, []]: |
|
current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps |
|
for dep_idx in current_deps_list: |
|
if dep_idx < len(sub_queries): |
|
dep_node_id = sub_query_ids[dep_idx] |
|
if dep_node_id in node_data_futures: |
|
dep_futures.append(node_data_futures[dep_node_id]) |
|
|
|
|
|
tasks.append(process_dependent_node( |
|
self, node_id, sub_queries[idx], depth, dep_futures, future |
|
)) |
|
|
|
|
|
if depth == 0: |
|
await self.emit_event("search_process_started", { |
|
"depth": depth, |
|
"sub_queries": sub_queries, |
|
"roles": roles |
|
}) |
|
|
|
|
|
await asyncio.gather(*tasks) |
|
|
|
|
|
if recurse: |
|
recursion_tasks = [] |
|
for idx, sub_query in enumerate(sub_queries): |
|
try: |
|
sub_node_id = sub_query_ids[idx] |
|
recursion_tasks.append( |
|
self.build_graph( |
|
query=sub_query, |
|
parent_node_id=sub_node_id, |
|
depth=depth + 1, |
|
threshold=threshold, |
|
recurse=recurse, |
|
context=context, |
|
session_id=session_id |
|
)) |
|
except Exception as e: |
|
print(f"Failed to create recursion task for sub-query {sub_query}: {e}") |
|
continue |
|
|
|
|
|
if recursion_tasks: |
|
try: |
|
await asyncio.gather(*recursion_tasks) |
|
except Exception as e: |
|
raise Exception(f"Error during recursive processing: {e}") |
|
|
|
|
|
if depth == 0: |
|
print("Graph building complete, processing final tasks...") |
|
await self.emit_event("search_process_completed", { |
|
"depth": depth, |
|
"sub_queries": sub_queries, |
|
"roles": roles |
|
}) |
|
|
|
|
|
create_cross_connections(self) |
|
print("All cross-connections have been created!") |
|
|
|
|
|
print(f"Adding similarity edges with threshold {threshold}") |
|
all_nodes = [] |
|
with self.driver.session() as session: |
|
                result = session.run(
                    "MATCH (n:Node) WHERE n.id <> $root_id AND n.graph_id = $graph_id RETURN n.id as id",
                    root_id=self.root_node_id,
                    graph_id=self.current_graph_id
                )
                all_nodes = [record["id"] for record in result]
|
|
|
for i, node1 in enumerate(all_nodes): |
|
for node2 in all_nodes[i+1:]: |
|
if not self.edge_exists(node1, node2): |
|
self.add_edge_based_on_similarity_and_relevance( |
|
node1, node2, query, threshold |
|
) |
|
|
|
print("All similarity-based edges have been added!") |
|
|
|
async def process_graph( |
|
self, |
|
query: str, |
|
data: str = None, |
|
similarity_threshold: float = 0.8, |
|
relevance_threshold: float = 0.7, |
|
sub_sub_queries: bool = True, |
|
session_id: str = None, |
|
max_tokens_allowed: int = 128000 |
|
): |
|
"""Process a query and manage graph creation/modification.""" |
|
|
|
|
|
def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]: |
|
if self.current_graph_id is None: |
|
raise Exception("Error: No current graph ID. Cannot check query similarity.") |
|
|
|
try: |
|
|
|
print(f"Retrieving existing queries and their metadata for graph {self.current_graph_id}") |
|
with self.transaction() as tx: |
|
result = tx.run(""" |
|
MATCH (n:Node) |
|
WHERE n.graph_id IS NOT NULL |
|
AND n.graph_id = $graph_id |
|
RETURN n.id as id, |
|
n.query as query, |
|
n.role as role |
|
""", |
|
graph_id=self.current_graph_id |
|
) |
|
|
|
|
|
similarities = [] |
|
records = list(result) |
|
|
|
if records == []: |
|
return {"should_create_new": True} |
|
|
|
for record in records: |
|
|
|
if not all([record["query"]]): |
|
continue |
|
|
|
|
|
similarity = self.calculate_query_similarity( |
|
new_query, |
|
record["query"] |
|
) |
|
|
|
if similarity >= similarity_threshold: |
|
similarities.append({ |
|
"node_id": record["id"], |
|
"query": record["query"], |
|
"score": similarity, |
|
"role": record["role"] |
|
}) |
|
|
|
|
|
if similarities == []: |
|
print(f"No similar queries found above threshold {similarity_threshold}") |
|
return {"should_create_new": True} |
|
|
|
|
|
best_match = max(similarities, key=lambda x: x["score"]) |
|
|
|
|
|
rel_type = "root" |
|
if "SSQ" in best_match["node_id"]: |
|
rel_type = "sub-sub" |
|
elif "SQ" in best_match["node_id"]: |
|
rel_type = "sub" |
|
|
|
return { |
|
"most_similar_query": best_match["query"], |
|
"similarity_score": best_match["score"], |
|
"relationship_type": rel_type, |
|
"node_id": best_match["node_id"], |
|
"should_create_new": best_match["score"] < similarity_threshold |
|
} |
|
|
|
except Exception as e: |
|
print(f"Error checking query similarity: {str(e)}") |
|
raise |
|
|
|
try: |
|
|
|
print("Checking for existing graphs...") |
|
result = self.get_graphs() |
|
graphs = list(result) |
|
|
|
if graphs == []: |
|
print("No existing graphs found. Creating new graph.") |
|
self.create_new_graph() |
|
|
|
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"}) |
|
await self.build_graph( |
|
query=query, |
|
data=data, |
|
threshold=relevance_threshold, |
|
recurse=sub_sub_queries, |
|
session_id=session_id, |
|
max_tokens_allowed=max_tokens_allowed |
|
) |
|
|
|
gc.collect() |
|
|
|
|
|
self.prune_edges() |
|
self.update_pagerank() |
|
|
|
|
|
self.verify_graph_integrity() |
|
self.verify_graph_consistency() |
|
return |
|
|
|
|
|
max_similarity = 0 |
|
most_similar_graph = None |
|
|
|
|
|
consolidated_graphs = {} |
|
for graph in graphs: |
|
graph_info = graph.get("graph_info") |
|
if not graph_info: |
|
continue |
|
|
|
graph_id = graph_info.get("graph_id") |
|
if not graph_id: |
|
continue |
|
|
|
|
|
if graph_id not in consolidated_graphs: |
|
consolidated_graphs[graph_id] = { |
|
"graph_id": graph_id, |
|
"nodes": [] |
|
} |
|
|
|
|
|
if graph_info.get("nodes"): |
|
consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"]) |
|
|
|
|
|
for graph_id, graph_data in consolidated_graphs.items(): |
|
nodes = graph_data["nodes"] |
|
|
|
|
|
for node in nodes: |
|
if node.get("query"): |
|
similarity = self.calculate_query_similarity( |
|
query, |
|
node["query"] |
|
) |
|
|
|
if node.get("id").startswith("SQ"): |
|
await self.emit_event("retrieved_sub_query", { |
|
"sub_query": node["query"] |
|
}) |
|
|
|
if similarity > max_similarity: |
|
max_similarity = similarity |
|
most_similar_graph = graph_id |
|
|
|
if max_similarity >= similarity_threshold: |
|
|
|
print(f"Found similar query with score {round(max_similarity, 2)}") |
|
self.current_graph_id = most_similar_graph |
|
|
|
if round(max_similarity, 2) == 1.0: |
|
print("Loading and using existing graph") |
|
|
|
await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"}) |
|
success = self.load_graph(self.root_node_id) |
|
if not success: |
|
raise Exception("Failed to load existing graph") |
|
else: |
|
|
|
print("Checking for node-level similarity...") |
|
similarity_info = check_query_similarity( |
|
query, |
|
similarity_threshold |
|
) |
|
|
|
if similarity_info["relationship_type"] in ["sub", "sub-sub"]: |
|
print(f"Most Similar Query: {similarity_info['most_similar_query']}") |
|
print("Modifying existing graph structure") |
|
|
|
await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"}) |
|
await self.modify_graph( |
|
query, |
|
similarity_info["node_id"], |
|
session_id=session_id |
|
) |
|
|
|
gc.collect() |
|
|
|
|
|
self.prune_edges() |
|
self.update_pagerank() |
|
|
|
|
|
self.verify_graph_integrity() |
|
self.verify_graph_consistency() |
|
else: |
|
|
|
print(f"Creating new graph for query: {query}") |
|
self.create_new_graph() |
|
|
|
await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"}) |
|
await self.build_graph( |
|
query=query, |
|
data=data, |
|
threshold=relevance_threshold, |
|
recurse=sub_sub_queries, |
|
session_id=session_id, |
|
max_tokens_allowed=max_tokens_allowed |
|
) |
|
|
|
gc.collect() |
|
|
|
|
|
self.prune_edges() |
|
self.update_pagerank() |
|
|
|
|
|
self.verify_graph_integrity() |
|
self.verify_graph_consistency() |
|
|
|
except Exception as e: |
|
print(f"Error in process_graph: {str(e)}") |
|
raise |
|
|
|
def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8): |
|
"""Add edges based on node similarity and relevance.""" |
|
try: |
|
with self.transaction() as tx: |
|
|
|
result = tx.run( |
|
""" |
|
MATCH (n1:Node {id: $node1_id}) |
|
WITH n1 |
|
MATCH (n2:Node {id: $node2_id}) |
|
RETURN n1.embedding as emb1, n1.data as data1, |
|
n2.embedding as emb2, n2.data as data2 |
|
""", |
|
node1_id=node1_id, |
|
node2_id=node2_id |
|
) |
|
data = result.single() |
|
if not data or not all([data["emb1"], data["emb2"], data["data1"], data["data2"]]): |
|
return |
|
|
|
|
|
similarity = self.cosine_similarity(data["emb1"], data["emb2"]) |
|
query_relevance1 = self.calculate_relevance(query, data["data1"]) |
|
query_relevance2 = self.calculate_relevance(query, data["data2"]) |
|
node_relevance = self.calculate_relevance(data["data1"], data["data2"]) |
|
|
|
|
|
weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4 |
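                # Worked example (illustrative numbers): similarity=0.9 and
                # relevance scores 0.8, 0.7, 0.6 give
                # weight = (0.9 + 0.8 + 0.7 + 0.6) / 4 = 0.75,
                # which with threshold=0.8 would NOT create an edge.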
|
|
|
|
|
if weight >= threshold: |
|
tx.run( |
|
""" |
|
MATCH (a:Node {id: $node1_id}), (b:Node {id: $node2_id}) |
|
MERGE (a)-[r:RELATION {type: 'similarity_and_relevance'}]->(b) |
|
ON CREATE SET r.weight = $weight |
|
ON MATCH SET r.weight = $weight |
|
""", |
|
node1_id=node1_id, |
|
node2_id=node2_id, |
|
weight=weight |
|
) |
|
print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}") |
|
|
|
except Exception as e: |
|
print(f"Error in similarity edge creation between {node1_id} and {node2_id}: {str(e)}") |
|
raise |
|
|
|
def calculate_relevance(self, data1: str, data2: str) -> float: |
|
"""Calculate relevance between two data.""" |
|
try: |
|
if not data1 or not data2: |
|
return 0.0 |
|
|
|
P, R, F1 = self.scorer.score([data1], [data2]) |
|
return F1.mean().item() |
|
|
|
except Exception as e: |
|
print(f"Error calculating relevance: {str(e)}") |
|
return 0.0 |
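    # NOTE: with rescale_with_baseline=True, BERTScore F1 can come out slightly
    # negative for unrelated texts, while downstream thresholds here assume
    # scores roughly in [0, 1]; callers compare against thresholds accordingly.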
|
|
|
def calculate_query_similarity(self, query1: str, query2: str) -> float: |
|
"""Calculate similarity between two queries.""" |
|
try: |
|
|
|
embedding1 = self.model.encode(query1).tolist() |
|
embedding2 = self.model.encode(query2).tolist() |
|
|
|
|
|
return self.cosine_similarity(embedding1, embedding2) |
|
|
|
except Exception as e: |
|
print(f"Error calculating query similarity: {str(e)}") |
|
return 0.0 |
|
|
|
def get_similarities_and_relevance(self, threshold: float = 0.8) -> list: |
|
"""Get similarities and relevance between nodes.""" |
|
try: |
|
with self.transaction() as tx: |
|
|
|
result = tx.run( |
|
""" |
|
MATCH (n:Node) |
|
WHERE n.id <> $root_id |
|
RETURN n.id as id, n.embedding as embedding, n.data as data |
|
""", |
|
root_id=self.root_node_id |
|
) |
|
|
|
nodes = list(result) |
|
similarities = [] |
|
|
|
|
|
for i, node1 in enumerate(nodes): |
|
for node2 in nodes[i + 1:]: |
|
similarity = self.cosine_similarity(node1["embedding"], node2["embedding"]) |
|
relevance = self.calculate_relevance(node1["data"], node2["data"]) |
|
|
|
|
|
weight = (similarity + relevance) / 2 |
|
|
|
|
|
if weight >= threshold: |
|
similarities.append({ |
|
'node1': node1["id"], |
|
'node2': node2["id"], |
|
'similarity': similarity, |
|
'relevance': relevance, |
|
'weight': weight |
|
}) |
|
|
|
return similarities |
|
|
|
except Exception as e: |
|
print(f"Error getting similarities and relevance: {str(e)}") |
|
return [] |
|
|
|
def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None): |
|
"""Get relationships between nodes with filtering options.""" |
|
try: |
|
with self.transaction() as tx: |
|
|
|
cypher_query = """ |
|
MATCH (n:Node) |
|
WHERE n.id <> $root_id |
|
AND n.graph_id = $current_graph_id |
|
""" |
|
params = { |
|
"root_id": self.root_node_id, |
|
"current_graph_id": self.current_graph_id |
|
} |
|
|
|
|
|
if node_id: |
|
cypher_query += " AND n.id = $node_id" |
|
params["node_id"] = node_id |
|
if role: |
|
cypher_query += " AND n.role = $role" |
|
params["role"] = role |
|
if depth is not None: |
|
cypher_query += " AND n.depth = $depth" |
|
params["depth"] = depth |
|
|
|
|
|
cypher_query += """ |
|
WITH n |
|
OPTIONAL MATCH (n)-[r1:RELATION]->(m1:Node) |
|
WHERE m1.id <> $root_id |
|
AND m1.graph_id = $current_graph_id |
|
""" |
|
|
|
|
|
if relationship_type: |
|
cypher_query += " AND r1.type = $rel_type" |
|
params["rel_type"] = relationship_type |
|
|
|
|
|
cypher_query += """ |
|
WITH n, collect({source: n.id, target: m1.id, weight: r1.weight, type: r1.type}) as out_edges |
|
OPTIONAL MATCH (n)<-[r2:RELATION]-(m2:Node) |
|
WHERE m2.id <> $root_id |
|
AND m2.graph_id = $current_graph_id |
|
""" |
|
|
|
|
|
if relationship_type: |
|
cypher_query += " AND r2.type = $rel_type" |
|
|
|
|
|
cypher_query += """ |
|
RETURN n.id as node_id, |
|
collect({source: m2.id, target: n.id, weight: r2.weight, type: r2.type}) as in_edges, |
|
out_edges |
|
""" |
|
|
|
result = tx.run(cypher_query, params) |
|
relationships = {} |
|
|
|
for record in result: |
|
node_id = record["node_id"] |
|
relationships[node_id] = { |
|
'in_edges': [(edge['source'], edge['target'], { |
|
'weight': edge['weight'], |
|
'type': edge['type'] |
|
}) for edge in record["in_edges"] if edge['source'] is not None], |
|
'out_edges': [(edge['source'], edge['target'], { |
|
'weight': edge['weight'], |
|
'type': edge['type'] |
|
}) for edge in record["out_edges"] if edge['target'] is not None] |
|
} |
|
|
|
return relationships |
|
|
|
except Exception as e: |
|
print(f"Error getting node relationships: {str(e)}") |
|
raise |
|
|
|
def find_nodes_by_properties(self, query: str = None, embedding: list = None, |
|
node_data: dict = None, similarity_threshold: float = 0.8) -> list: |
|
"""Find nodes based on properties.""" |
|
try: |
|
with self.transaction() as tx: |
|
                where_conditions = []
                params = {}
|
|
|
|
|
if query: |
|
where_conditions.append("n.query CONTAINS $node_query") |
|
params["node_query"] = query |
|
|
|
if node_data: |
|
for key, value in node_data.items(): |
|
where_conditions.append(f"n.{key} = ${key}") |
|
params[key] = value |
|
|
|
|
|
cypher_query = "MATCH (n:Node)" |
|
if where_conditions: |
|
cypher_query += " WHERE " + " AND ".join(where_conditions) |
|
cypher_query += " RETURN n" |
|
|
|
result = tx.run(cypher_query, params) |
|
matching_nodes = [] |
|
|
|
|
|
for record in result: |
|
node = record["n"] |
|
match_score = 0 |
|
matches = 0 |
|
|
|
|
|
if query and query.lower() in node["query"].lower(): |
|
match_score += 1 |
|
matches += 1 |
|
|
|
|
|
if embedding and "embedding" in node: |
|
similarity = self.cosine_similarity(embedding, node["embedding"]) |
|
if similarity >= similarity_threshold: |
|
match_score += similarity |
|
matches += 1 |
|
|
|
|
|
if node_data: |
|
data_matches = sum(1 for k, v in node_data.items() |
|
if k in node and node[k] == v) |
|
if data_matches > 0: |
|
match_score += data_matches / len(node_data) |
|
matches += 1 |
|
|
|
|
|
if matches > 0: |
|
matching_nodes.append({ |
|
"node_id": node["id"], |
|
"score": match_score / matches, |
|
"data": dict(node) |
|
}) |
|
|
|
|
|
matching_nodes.sort(key=lambda x: x["score"], reverse=True) |
|
return matching_nodes |
|
|
|
except Exception as e: |
|
print(f"Error finding nodes by properties: {str(e)}") |
|
raise |
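    # Scoring sketch: a node that matches the query substring (contributes 1)
    # and one of two node_data keys (contributes 1/2 = 0.5) averages to
    # (1 + 0.5) / 2 = 0.75, since the final score is match_score / matches.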
|
|
|
def query_graph(self, query: str) -> str: |
|
"""Query the graph in Neo4j for a specific query, collecting data from the entire relevant subgraph.""" |
|
try: |
|
with self.transaction() as tx: |
|
|
|
query_node = tx.run(""" |
|
MATCH (n:Node {query: $node_query}) |
|
WHERE n.graph_id = $graph_id |
|
RETURN n |
|
""", node_query=query, graph_id=self.current_graph_id).single() |
|
|
|
if not query_node: |
|
raise ValueError(f"Query node not found for: {query}") |
|
|
|
query_node_id = query_node['n']['id'] |
|
datas = [] |
|
|
|
|
|
subgraph_paths = tx.run(""" |
|
// First get the query node and all its connected paths |
|
MATCH path = (n:Node {id: $node_id})-[r:RELATION*0..]->(m:Node) |
|
WHERE n.graph_id = $graph_id |
|
|
|
// Collect all nodes and relationships in these paths |
|
WITH COLLECT(path) as paths |
|
UNWIND paths as path |
|
WITH DISTINCT path |
|
|
|
// Get all nodes and relationships from the paths |
|
WITH nodes(path) as nodes, relationships(path) as rels |
|
|
|
// Calculate path weight considering all relationship types |
|
WITH nodes, rels, |
|
reduce(weight = 1.0, rel in rels | |
|
CASE rel.type |
|
WHEN 'logical' THEN weight * rel.weight * 1.2 |
|
WHEN 'hierarchical' THEN weight * rel.weight * 1.1 |
|
WHEN 'similarity_and_relevance' THEN weight * rel.weight * 0.9 |
|
ELSE weight * rel.weight |
|
END |
|
) as path_weight |
|
|
|
// Unwind nodes to get individual records |
|
UNWIND nodes as node |
|
WITH DISTINCT node, path_weight |
|
WHERE node.data IS NOT NULL |
|
AND node.data <> '' // Ensure data is not empty |
|
|
|
// Return ordered by weight and pagerank for better context flow |
|
RETURN node.data as data, |
|
path_weight, |
|
node.role as role, |
|
node.pagerank as pagerank |
|
ORDER BY |
|
CASE node.role |
|
WHEN 'pre-requisite' THEN 3 |
|
WHEN 'independent' THEN 2 |
|
ELSE 1 |
|
END DESC, |
|
path_weight DESC, |
|
pagerank DESC |
|
""", node_id=query_node_id, graph_id=self.current_graph_id) |
|
|
|
|
|
for record in subgraph_paths: |
|
data = record["data"] |
|
if data and isinstance(data, str): |
|
datas.append(data.strip()) |
|
|
|
|
|
if datas == []: |
|
print(f"No data found for: {query}") |
|
return "" |
|
|
|
|
|
return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)]) |
|
|
|
except Exception as e: |
|
print(f"Error querying graph for specific query: {str(e)}") |
|
raise |
|
|
|
    def prune_edges(self, max_edges: int = 1000):
        """Prune excess edges, keeping the top-weighted ones, while preserving node data."""
        try:
            print(f"Pruning edges to keep top {max_edges} edges by weight...")
            with self.transaction() as tx:
                # Count the edges currently in this graph
                result = tx.run(
                    """
                    MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                    RETURN count(r) AS count
                    """,
                    graphID=self.current_graph_id
                )
                current_edges = result.single()["count"]

                if current_edges > max_edges:
                    # Relationships cannot carry labels in Neo4j (labels are a
                    # node concept), so mark the edges to keep with a temporary
                    # property instead of `SET r:KEEP`
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        WITH r
                        ORDER BY r.weight DESC
                        LIMIT $max_edges
                        SET r.keep = true
                        """,
                        graphID=self.current_graph_id,
                        max_edges=max_edges
                    )

                    # Delete every edge that was not marked
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        WHERE r.keep IS NULL
                        DELETE r
                        """,
                        graphID=self.current_graph_id
                    )

                    # Clear the temporary marker; the commit is handled by the
                    # transaction context manager
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        REMOVE r.keep
                        """,
                        graphID=self.current_graph_id
                    )

                    print(f"Pruned edges. Kept top {max_edges} edges by weight.")
                else:
                    print("No pruning needed. Current edge count is within limits.")

        except Exception as e:
            print(f"Error pruning edges: {str(e)}")
            raise
|
|
|
def update_pagerank(self): |
|
"""Update PageRank values using Neo4j's graph algorithms.""" |
|
if not self.current_graph_id: |
|
print("No current graph selected. Cannot compute PageRank.") |
|
return |
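        # NOTE: this method relies on the Neo4j Graph Data Science (GDS) plugin
        # being installed; without it, the gds.* procedure calls below will fail.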
|
|
|
try: |
|
with self.transaction() as tx: |
|
|
|
tx.run( |
|
""" |
|
CALL gds.graph.project.cypher( |
|
'graphProjection', |
|
'MATCH (n:Node) WHERE n.graph_id = $myParam RETURN id(n) AS id', |
|
'MATCH (n:Node)-[r:RELATION]->(m:Node) |
|
WHERE n.graph_id = $myParam AND m.graph_id = $myParam |
|
RETURN id(n) AS source, |
|
id(m) AS target, |
|
CASE r.type |
|
WHEN "logical" THEN r.weight * 2 |
|
ELSE r.weight |
|
END AS weight', |
|
{ parameters: { myParam: $graphId } } |
|
) |
|
""", |
|
graphId=self.current_graph_id |
|
) |
|
|
|
|
|
tx.run( |
|
""" |
|
CALL gds.pageRank.write( |
|
'graphProjection', |
|
{ |
|
relationshipWeightProperty: 'weight', |
|
writeProperty: 'pagerank', |
|
maxIterations: 20, |
|
dampingFactor: 0.85, |
|
concurrency: 4 |
|
} |
|
) |
|
""" |
|
) |
|
|
|
|
|
tx.run( |
|
""" |
|
CALL gds.graph.drop('graphProjection') |
|
""" |
|
) |
|
|
|
print("PageRank updated successfully") |
|
|
|
except Exception as e: |
|
print(f"Error updating PageRank: {str(e)}") |
|
raise |
|
|
|
def display_graph(self, query: str): |
|
"""Display the graph""" |
|
try: |
|
with self.transaction() as tx: |
|
|
|
cypher_query = """ |
|
MATCH (n:Node) |
|
WHERE n.query = $node_query |
|
RETURN COLLECT(DISTINCT n.graph_id) AS graph_ids |
|
""" |
|
result = tx.run(cypher_query, node_query=query) |
|
graph_ids = result.single().get("graph_ids", []) |
|
|
|
if not graph_ids: |
|
print("No graph found for the given query.") |
|
return |
|
|
|
|
|
net = Network( |
|
height="600px", |
|
width="100%", |
|
directed=True, |
|
bgcolor="#222222", |
|
font_color="white" |
|
) |
|
|
|
|
|
net.options = {"physics": {"enabled": False}} |
|
|
|
all_nodes = set() |
|
all_edges = [] |
|
|
|
for graph_id in graph_ids: |
|
|
|
result = tx.run(f"MATCH (n)-[r]->(m) WHERE n.graph_id = '{graph_id}' RETURN n, r, m") |
|
|
|
for record in result: |
|
source_node = record["n"] |
|
target_node = record["m"] |
|
relationship = record["r"] |
|
|
|
source_id = source_node.get("id") |
|
target_id = target_node.get("id") |
|
|
|
|
|
source_tooltip = ( |
|
f"Query: {source_node.get('query', 'N/A')}" |
|
) |
|
|
|
target_tooltip = ( |
|
f"Query: {target_node.get('query', 'N/A')}" |
|
) |
|
|
|
|
|
if source_id not in all_nodes: |
|
net.add_node( |
|
source_id, |
|
label=source_id, |
|
title=source_tooltip, |
|
size=20, |
|
color="#00cc66" |
|
) |
|
all_nodes.add(source_id) |
|
|
|
|
|
if target_id not in all_nodes: |
|
net.add_node( |
|
target_id, |
|
label=target_id, |
|
title=target_tooltip, |
|
size=20, |
|
color="#00cc66" |
|
) |
|
all_nodes.add(target_id) |
|
|
|
|
|
all_edges.append({ |
|
"from": source_id, |
|
"to": target_id, |
|
"label": relationship.type, |
|
}) |
|
|
|
|
|
for edge in all_edges: |
|
net.add_edge( |
|
edge["from"], |
|
edge["to"], |
|
title=edge["label"], |
|
color="#cccccc" |
|
) |
|
|
|
|
|
net.options["layout"] = {"improvedLayout": True} |
|
net.options["interaction"] = {"dragNodes": True} |
|
|
|
|
|
net.save_graph("temp_graph.html") |
|
with open("temp_graph.html", "r", encoding="utf-8") as f: |
|
html_str = f.read() |
|
|
|
os.remove("temp_graph.html") |
|
|
|
return html_str |
|
|
|
except Exception as e: |
|
print(f"Error displaying graph: {str(e)}") |
|
raise |
|
|
|
def verify_graph_integrity(self): |
|
"""Verify and fix graph integrity issues.""" |
|
try: |
|
with self.transaction() as tx: |
|
|
|
orphaned = tx.run( |
|
""" |
|
MATCH (n:Node {graph_id: $graph_id}) |
|
WHERE NOT (n)-[:RELATION]-() |
|
RETURN n.id as node_id |
|
""", |
|
graph_id=self.current_graph_id |
|
).values() |
|
|
|
if orphaned: |
|
print(f"Found orphaned nodes: {orphaned}") |
|
|
|
|
|
invalid_edges = tx.run( |
|
""" |
|
MATCH (a:Node)-[r:RELATION]->(b:Node) |
|
WHERE a.graph_id = $graph_id |
|
AND (b.graph_id <> $graph_id OR b.graph_id IS NULL) |
|
RETURN a.id as from_id, b.id as to_id |
|
""", |
|
graph_id=self.current_graph_id |
|
).values() |
|
|
|
if invalid_edges: |
|
print(f"Found invalid edges: {invalid_edges}") |
|
|
|
tx.run( |
|
""" |
|
MATCH (a:Node)-[r:RELATION]->(b:Node) |
|
WHERE a.graph_id = $graph_id |
|
AND (b.graph_id <> $graph_id OR b.graph_id IS NULL) |
|
DELETE r |
|
""", |
|
graph_id=self.current_graph_id |
|
) |
|
|
|
print("Graph integrity verified successfully") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"Error verifying graph integrity: {str(e)}") |
|
raise |
|
|
|
def verify_graph_consistency(self): |
|
"""Verify consistency of the Neo4j graph.""" |
|
try: |
|
with self.driver.session() as session: |
|
|
|
missing_props = session.run(""" |
|
MATCH (n:Node) |
|
WHERE n.id IS NULL OR n.query IS NULL |
|
RETURN count(n) as count |
|
""") |
|
|
|
if missing_props.single()["count"] > 0: |
|
raise ValueError("Found nodes with missing required properties") |
|
|
|
|
|
invalid_rels = session.run(""" |
|
MATCH ()-[r:RELATION]->() |
|
WHERE r.type IS NULL OR r.weight IS NULL |
|
RETURN count(r) as count |
|
""") |
|
|
|
if invalid_rels.single()["count"] > 0: |
|
raise ValueError("Found relationships with missing required properties") |
|
|
|
print("Graph consistency verified successfully") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"Error verifying graph consistency: {str(e)}") |
|
raise |
|
|
|
    async def close(self):
        """Properly clean up all resources."""
        try:
            # Shut down the thread pool
            if hasattr(self, 'executor'):
                self.executor.shutdown(wait=True)

            # Close the Neo4j driver
            if hasattr(self, 'driver'):
                self.driver.close()

            # Clean up crawler sessions; the crawler instance lives on
            # `custom_crawler`, and a per-run `session_id` may not have been set
            if hasattr(self, 'custom_crawler'):
                await asyncio.shield(self.custom_crawler.cleanup_expired_sessions())
                session_id = getattr(self, 'session_id', None)
                if session_id is not None:
                    await asyncio.shield(self.custom_crawler.cleanup_browser_context(session_id))

        except Exception as e:
            print(f"Error during cleanup: {e}")
|
|
|
    @staticmethod
    def cosine_similarity(v1: List[float], v2: List[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        try:
            v1_array = np.array(v1)
            v2_array = np.array(v2)
            denom = np.linalg.norm(v1_array) * np.linalg.norm(v2_array)
            # Guard against zero vectors, which would otherwise divide by zero
            if denom == 0:
                return 0.0
            return float(np.dot(v1_array, v2_array) / denom)

        except Exception as e:
            print(f"Error calculating cosine similarity: {str(e)}")
            return 0.0
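    # Worked example: cosine_similarity([1.0, 0.0], [0.0, 1.0]) == 0.0, while
    # identical non-zero vectors score 1.0; zero vectors fall back to 0.0.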
|
|
|
if __name__ == "__main__": |
|
|
from dotenv import load_dotenv |
|
from src.reasoning.reasoner import Reasoner |
|
from src.evaluation.evaluator import Evaluator |
|
|
|
load_dotenv() |
|
|
|
graph_search = Neo4jGraphRAG(num_workers=24) |
|
evaluator = Evaluator() |
|
reasoner = Reasoner() |
|
|
|
async def test_graph_search(): |
|
|
|
queries = [ |
|
"""In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries. |
|
Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including: |
|
1. Carbon pricing mechanisms |
|
2. Subsidies for emerging technologies (such as green hydrogen and battery storage) |
|
3. Cross-border climate finance initiatives |
|
|
|
Highlight the unique challenges faced by both advanced and emerging economies in addressing: |
|
1. Energy poverty |
|
2. Supply chain disruptions |
|
3. Geopolitical tensions (e.g., the Russia-Ukraine conflict) |
|
|
|
Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets. |
|
Note any significant policy gaps, regional collaborations, or innovative best practices. |
|
Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering: |
|
1. Technological breakthroughs |
|
2. Global market trends |
|
3. Potential climate-related disasters |
|
|
|
Present your analysis as a detailed, well-formatted report.""", |
|
"""Analyse the impact of 'hot-money' on the value of Indian Rupee and answer the following questions:- |
|
1. How does it affect the exchange rate? |
|
2. How can it be mitigated/eliminated? |
|
3. Why is it a problem? |
|
4. What are the consequences? |
|
5. What are the alternatives? |
|
- Evaluate the alternatives for pros and cons. |
|
- Evaluate the impact of alternatives on the exchange rate. |
|
- How can they be implemented? |
|
- What are the consequences of each alternative? |
|
- Evaluate the feasibility of the alternatives. |
|
- Pick top 5 alternatives and justify your choices in detail. |
|
6. What are the implications for the Indian economy? Furthermore:- |
|
- Evaluate the impact of the chosen alternatives on the Indian economy.""", |
|
"""Inflation has been an intrinsic past of human civilization since the very beginning. Answer the following questions:- |
|
1. How true is the above statement? |
|
2. What are the causes of inflation? |
|
3. What are the consequences of inflation? |
|
4. Can we completely eliminate inflation?""", |
|
"""Perform a detailed comparison between the ancient Greece and Roman civilizations. |
|
1. What were the key differences between the two civilizations? |
|
- Evaluate the differences in governance, society, and culture |
|
- Evaluate the differences in economy, trade, and military |
|
- Evaluate the differences in technology and infrastructure |
|
2. What were the similarities between the two civilizations? |
|
- Evaluate the similarities in governance, society, and culture |
|
- Evaluate the similarities in economy, trade, and military |
|
- Evaluate the similarities in technology and infrastructure |
|
3. How did these two civilizations influence each other? |
|
- Evaluate the influence of one civilization on the other |
|
4. How did these two civilizations influence the modern world? |
|
5. Was there another civilization that influenced these two? If yes, how?""", |
|
"""Evaluate the long-term effects of colonialism on economic development in Asia:- |
|
1. Include case studies of at least five different countries |
|
2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution |
|
- Evaluate the impact of colonialism on the economy of the country |
|
- Evaluate the impact of colonialism on the economy of the region |
|
- Evaluate the impact of colonialism on the economy of the world |
|
3. How do these effects compare to Africa?""" |
|
] |
|
follow_on_queries = [ |
|
"How is 'hot-money' related to the current economic situation in India?", |
|
"What is inflation?", |
|
"Did ancient Greece and Rome have any impact on modern democracy? If yes, how?", |
|
"Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?" |
|
] |
|
|
|
query = queries[2] |
|
|
|
|
|
graph_search.initialize_schema() |
|
|
|
|
|
await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8) |
|
|
|
|
|
answer = graph_search.query_graph(query) |
|
response = "" |
|
async for chunk in reasoner.answer(query, answer): |
|
response += chunk |
|
print(response, end="", flush=True) |
|
|
|
|
|
graph_search.display_graph(query) |
|
|
|
|
|
evaluation = await evaluator.evaluate_response(query, response, [answer]) |
|
print(f"Faithfulness: {evaluation['faithfulness']}") |
|
print(f"Answer Relevancy: {evaluation['answer relevancy']}") |
|
print(f"Context Utilization: {evaluation['contextual recall']}") |
|
|
|
|
|
await graph_search.close() |
|
|
|
|
|
asyncio.run(test_graph_search()) |