# ai/knowledge.py
# Last commit: 6093500 "knowledge graph" (kevinhug)
# Size: 2.31 kB (copied from the Hugging Face file view: raw / history / blame)
from openai import OpenAI
import instructor
from pydantic import BaseModel, Field
from typing import List
from graphviz import Digraph
class Node(BaseModel, frozen=True):
    """
    Node representing a concept in the subject domain
    """

    # Integer key of the node; Edge.source / Edge.target refer to nodes by this
    # value (see visualize_knowledge_graph, which keys Graphviz nodes on str(id)).
    id: int = Field(
        ...,
        description="unique integer id of the concept; edges refer to the concept by this id",
    )
    # Text rendered inside the node.
    label: str = Field(..., description="description of the concept in the subject domain")
    # Graphviz color used to draw the node; defaults to black like Edge.color so an
    # omitted color does not fail validation.
    color: str = Field(
        "black",
        description="Graphviz color name used to draw the concept, e.g. 'blue'",
    )
class Edge(BaseModel, frozen=True):
    """
    Edge representing a relationship between concepts in the subject domain
    """

    # source/target must equal the id of a Node in the same KnowledgeGraph:
    # Graphviz implicitly creates an extra node (labelled with the bare id) for
    # any edge endpoint that was never declared.
    source: int = Field(
        ...,
        description="id of the concept (node) the relationship starts from",
    )
    target: int = Field(
        ...,
        description="id of the concept (node) the relationship points to",
    )
    label: str = Field(
        ...,
        description="description of the relationship between the source and target concepts",
    )
    # Graphviz color used to draw the edge.
    color: str = Field(
        "black",
        description="Graphviz color name used to draw the relationship",
    )
class KnowledgeGraph(BaseModel):
    """
    Graph representation of the subject domain
    """

    # Both lists are optional and default to empty. The previous
    # Field(..., default_factory=list) combined the "required" marker with a
    # default factory: pydantic v1 raises on that, and v2 ignores the `...`.
    nodes: List[Node] = Field(
        default_factory=list,
        description="concepts in the subject domain",
    )
    edges: List[Edge] = Field(
        default_factory=list,
        description="relationships between the concepts, referring to them by node id",
    )
from groq import Groq
import os
# LLM client used by generate_graph. The Groq API key is read from the
# GROQ_API_KEY environment variable; instructor wraps the client so that
# chat.completions.create accepts a pydantic `response_model`.
client = instructor.from_groq(Groq(api_key=os.getenv("GROQ_API_KEY")))

# Alternative (disabled): a local Ollama server through its OpenAI-compatible
# endpoint, using instructor's JSON mode.
#
# client = instructor.from_openai(
#     OpenAI(
#         base_url="http://localhost:11434/v1",
#         api_key="ollama",
#     ),
#     mode=instructor.Mode.JSON,
# )
def generate_graph(
    input,
    *,
    model: str = "llama-3.1-8b-instant",
    max_retries: int = 5,
) -> KnowledgeGraph:
    """Ask the LLM to describe ``input`` as a structured knowledge graph.

    Args:
        input: Subject to explain; interpolated verbatim into the user prompt.
            (The name shadows the builtin but is kept for keyword callers.)
        model: Model name sent with the chat completion request. Defaults to
            ``"llama-3.1-8b-instant"``; the code previously also noted
            ``"llama3.2"``, presumably for the local Ollama client — confirm.
        max_retries: Forwarded to instructor, which re-asks the model when its
            reply fails validation against ``KnowledgeGraph``.

    Returns:
        The ``KnowledgeGraph`` parsed from the model's response by instructor.
    """
    prompt = (
        "Help me understand the following by describing it as a detailed "
        f"knowledge graph: {input}"
    )
    return client.chat.completions.create(
        model=model,
        max_retries=max_retries,
        messages=[{"role": "user", "content": prompt}],
        response_model=KnowledgeGraph,
    )
def visualize_knowledge_graph(
    kg: KnowledgeGraph,
    *,
    filename: str = "knowledge_graph",
    fmt: str = "png",
) -> str:
    """Draw ``kg`` with Graphviz, render it to disk and return the output path.

    Args:
        kg: Graph to draw. Nodes are keyed by ``str(node.id)``; each edge joins
            ``str(edge.source)`` to ``str(edge.target)``.
        filename: Output path without extension, passed to ``Digraph.render``
            (which also saves the DOT source under this name).
        fmt: Graphviz output format.

    Returns:
        The path of the rendered file as returned by ``Digraph.render``
        (e.g. ``knowledge_graph.png`` with the defaults).
    """
    dot = Digraph(comment="Knowledge Graph")

    for node in kg.nodes:
        dot.node(str(node.id), node.label, color=node.color)

    # Endpoints are matched to nodes by id; Graphviz implicitly creates a node
    # for any endpoint id that was not declared in the loop above.
    for edge in kg.edges:
        dot.edge(
            str(edge.source),
            str(edge.target),
            label=edge.label,
            color=edge.color,
        )

    return dot.render(filename, format=fmt)
def graph(query):
    """Generate a knowledge graph for ``query`` and render it as an image.

    Returns the relative path of the PNG written by ``visualize_knowledge_graph``
    (rendered under the name ``knowledge_graph``).
    """
    visualize_knowledge_graph(generate_graph(query))
    return "./knowledge_graph.png"