|
from openai import OpenAI |
|
import instructor |
|
|
|
from pydantic import BaseModel, Field |
|
from typing import List |
|
from graphviz import Digraph |
|
|
|
class Node(BaseModel, frozen=True):
    """
    Node representing a concept in the subject domain.
    """

    # Integer key for this concept; Edge.source / Edge.target hold these ids
    # and visualize_knowledge_graph uses str(id) as the Graphviz node name.
    id: int = Field(..., description="unique integer id of the concept; edges refer to nodes by this id")
    label: str = Field(..., description="description of the concept in the subject domain")
    # Passed straight to Graphviz as the node's `color` attribute.
    color: str = Field(..., description="Graphviz color name used to draw the concept")
|
|
|
class Edge(BaseModel, frozen=True):
    """
    Edge representing a relationship between two concepts in the subject domain.
    """

    # Both endpoints are Node ids (not labels): visualize_knowledge_graph draws
    # the edge between str(source) and str(target).
    source: int = Field(..., description="id of the node where the relationship starts")
    target: int = Field(..., description="id of the node where the relationship ends")
    label: str = Field(..., description="description of the relationship between the source and target concepts")
    # Passed straight to Graphviz as the edge's `color` attribute.
    color: str = Field("black", description="Graphviz color name used to draw the relationship")
|
|
|
class KnowledgeGraph(BaseModel):
    """
    Graph representation of a subject domain: concepts as nodes, relationships as edges.
    """

    # Only `default_factory` is given: combining it with `...` ("required") is
    # contradictory and is rejected outright by pydantic v1. Both lists default
    # to a fresh empty list per instance.
    nodes: List[Node] = Field(default_factory=list, description="concepts in the subject domain")
    edges: List[Edge] = Field(
        default_factory=list,
        description="relationships between concepts; each edge's source and target are node ids",
    )
|
|
|
from groq import Groq |
|
import os |
|
|
|
# Groq chat client wrapped by instructor, so that generate_graph can call
# client.chat.completions.create(..., response_model=KnowledgeGraph).
# NOTE(review): os.getenv returns None when GROQ_API_KEY is unset; presumably
# the Groq SDK then raises at construction time — confirm.
client = instructor.from_groq(Groq(api_key=os.getenv("GROQ_API_KEY")))

# Alternative backend (disabled): a local Ollama server through its
# OpenAI-compatible endpoint. To use it, replace the assignment above with the
# code below and switch generate_graph to a model that Ollama serves.
#
# client = instructor.from_openai(
#     OpenAI(
#         base_url="http://localhost:11434/v1",
#         api_key="ollama",
#     ),
#     mode=instructor.Mode.JSON,
# )
|
def generate_graph(
    input: str,
    *,
    model: str = 'llama-3.1-8b-instant',
    max_retries: int = 5,
) -> KnowledgeGraph:
    """Ask the LLM to describe ``input`` as a knowledge graph.

    Args:
        input: Topic or text to explain; interpolated verbatim into the user
            prompt. (The name shadows the ``input`` builtin but is kept so
            existing keyword callers keep working.)
        model: Chat model name sent to the module-level ``client``.
        max_retries: Forwarded to ``client.chat.completions.create``;
            instructor uses it as the re-ask budget when the model's reply
            fails ``KnowledgeGraph`` validation.

    Returns:
        The ``KnowledgeGraph`` parsed from the model's reply.
    """
    return client.chat.completions.create(
        model=model,
        max_retries=max_retries,
        messages=[
            {
                "role": "user",
                "content": f"Help me understand the following by describing it as a detailed knowledge graph: {input}",
            }
        ],
        response_model=KnowledgeGraph,
    )
|
|
|
|
|
def visualize_knowledge_graph(
    kg: KnowledgeGraph,
    *,
    filename: str = "knowledge_graph",
    output_format: str = "png",
) -> str:
    """Draw ``kg`` with Graphviz, write it to disk, and return the image path.

    Every node is drawn with its ``label`` and ``color``; every edge goes from
    ``str(source)`` to ``str(target)`` with its ``label`` and ``color``. In DOT,
    an edge endpoint that was never declared as a node is created implicitly,
    so an edge to an unknown id shows up as an extra node named by that id.

    Args:
        kg: Graph to draw.
        filename: Output path without extension. The DOT source is written to
            this path and the rendered image next to it.
        output_format: Graphviz output format, e.g. ``"png"`` or ``"svg"``.

    Returns:
        The path of the rendered image, as returned by ``Digraph.render``.
    """
    dot = Digraph(comment="Knowledge Graph")

    for node in kg.nodes:
        dot.node(str(node.id), node.label, color=node.color)

    for edge in kg.edges:
        dot.edge(str(edge.source), str(edge.target), label=edge.label, color=edge.color)

    # Return render()'s own result so callers do not have to rebuild the
    # "<filename>.<format>" path by hand.
    return dot.render(filename, format=output_format)
|
|
|
def graph(query):
    """Turn ``query`` into a knowledge-graph image and return where it was saved.

    The graph is generated by the LLM, rendered with Graphviz, and the fixed
    relative path ``./knowledge_graph.png`` is returned.
    """
    png_path = "./knowledge_graph.png"
    visualize_knowledge_graph(generate_graph(query))
    return png_path