File size: 2,308 Bytes
91b56ef 6093500 91b56ef 6093500 91b56ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import instructor
from pydantic import BaseModel, Field
from typing import List
from graphviz import Digraph
class Node(BaseModel, frozen=True):
    """
    Node representing concept in the subject domain
    """

    # Integer identifier. When the graph is drawn, str(id) is the Graphviz node
    # name, and Edge.source / Edge.target are matched against it.
    id: int
    # Human-readable text shown as the node's label in the rendered graph.
    label: str = Field(..., description="description of the concept in the subject domain")
    # Passed straight through as the Graphviz ``color`` attribute. It has no
    # default, so the model must supply one for every node.
    color: str
class Edge(BaseModel):
    """
    Edge representing relationship between concepts in the subject domain
    """

    # Immutable instances (pydantic ``frozen`` configuration).
    model_config = {"frozen": True}

    # Endpoints are integer node ids. When the graph is drawn, they are matched
    # against Node.id (both are converted with str()).
    source: int = Field(description="source representing concept in the subject domain")
    target: int = Field(description="target representing concept in the subject domain")
    # Relationship name, drawn as the edge label.
    label: str = Field(
        description="description representing relationship between concepts in the subject domain"
    )
    # Graphviz colour for the edge. It is optional, so the model may omit it.
    color: str = Field(default="black")
class KnowledgeGraph(BaseModel):
    """
    graph representation of subject domain
    """

    # Both collections are optional, and each instance gets its own fresh empty
    # list. ``default_factory`` must not be combined with the ``...`` "required"
    # marker: the two contradict each other, and pydantic v1 rejects the pair.
    nodes: List[Node] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
import os

from groq import Groq

# Module-level chat client used by generate_graph: a Groq client authenticated
# from the GROQ_API_KEY environment variable, patched by instructor so that
# ``chat.completions.create`` accepts a ``response_model``.
client = instructor.from_groq(Groq(api_key=os.getenv("GROQ_API_KEY")))

# Alternative backend: a local Ollama server through its OpenAI-compatible API.
# To use it, replace the Groq client above with the following. Also switch the
# model name used in generate_graph to a local one (e.g. "llama3.2").
#
# from openai import OpenAI
# client = instructor.from_openai(
#     OpenAI(
#         base_url="http://localhost:11434/v1",
#         api_key="ollama",
#     ),
#     mode=instructor.Mode.JSON,
# )
def generate_graph(
    input: str,
    *,
    model: str = "llama-3.1-8b-instant",
    max_retries: int = 5,
) -> KnowledgeGraph:
    """Ask the chat model to describe *input* as a structured knowledge graph.

    Args:
        input: Topic or question to explain. The name shadows the builtin, but
            it is kept so that existing keyword callers keep working.
        model: Name of the chat model to query on the module-level ``client``.
        max_retries: Retry budget handed to instructor for replies that fail
            validation against ``KnowledgeGraph``.

    Returns:
        The ``KnowledgeGraph`` produced by the client call with
        ``response_model=KnowledgeGraph``.
    """
    return client.chat.completions.create(
        model=model,
        max_retries=max_retries,
        messages=[
            {
                "role": "user",
                "content": f"Help me understand the following by describing it as a detailed knowledge graph: {input}",
            }
        ],
        response_model=KnowledgeGraph,
    )
def visualize_knowledge_graph(
    kg: KnowledgeGraph,
    *,
    filename: str = "knowledge_graph",
    output_format: str = "png",
) -> str:
    """Draw *kg* with Graphviz, render it to disk, and return the output path.

    Each node becomes a Graphviz node named ``str(node.id)`` and labelled with
    ``node.label``. Each edge joins ``str(edge.source)`` to ``str(edge.target)``.
    Graphviz saves the DOT source to *filename* and the rendered image to
    ``f"{filename}.{output_format}"`` (``knowledge_graph.png`` by default).

    Args:
        kg: Graph to draw.
        filename: Base path for the DOT source and the rendered image.
        output_format: Graphviz output format, e.g. ``"png"`` or ``"svg"``.

    Returns:
        Path of the rendered file, as returned by ``Digraph.render``.
    """
    dot = Digraph(comment="Knowledge Graph")
    for node in kg.nodes:
        dot.node(str(node.id), node.label, color=node.color)
    # Edges address nodes by the same str(id) names used above. In DOT, an
    # endpoint with no matching node is created implicitly as a bare node.
    for edge in kg.edges:
        dot.edge(str(edge.source), str(edge.target), label=edge.label, color=edge.color)
    return dot.render(filename, format=output_format)
def graph(query):
    """Generate a knowledge graph for *query*, render it, and return the image path.

    The returned name matches the file written by visualize_knowledge_graph
    with its default settings.
    """
    visualize_knowledge_graph(generate_graph(query))
    return "knowledge_graph.png"