|
from typing import List |
|
|
|
import instructor |
|
from graphviz import Digraph |
|
from pydantic import BaseModel, Field |
|
|
|
from groq import Groq |
|
import os |
|
|
|
|
|
# Structured-output client: instructor wraps the Groq SDK client so that
# completions can be requested with a `response_model=` (see the functions
# below).
client = instructor.from_groq(Groq(api_key=os.getenv("GROQ_API_KEY")))

# Alternative backend (currently disabled): a local Ollama server through its
# OpenAI-compatible endpoint, in JSON mode. To use it, replace the client
# above with:
#
#     from openai import OpenAI
#
#     client = instructor.from_openai(
#         OpenAI(
#             base_url="http://localhost:11434/v1",
#             api_key="ollama",
#         ),
#         mode=instructor.Mode.JSON,
#     )

# Model name: the Groq-hosted model when a Groq key is configured, otherwise
# the name of the local Ollama model used by the disabled backend above.
# NOTE(review): `client` is the Groq client unconditionally, so the
# "llama3.2" fallback only makes sense once the Ollama backend is re-enabled;
# presumably Groq() cannot authenticate without a key — confirm.
if os.getenv("GROQ_API_KEY"):
    llm = 'llama-3.1-8b-instant'
else:
    llm = "llama3.2"
|
|
|
class Node(BaseModel, frozen=True):

    """

    Node representing concept in the subject domain

    """

    # NOTE: the docstring above and the Field descriptions below are kept
    # verbatim — presumably instructor turns them into the JSON schema shown
    # to the LLM, so rewording them changes the prompt.
    # frozen=True makes instances immutable and hashable, which
    # KnowledgeGraph.update() relies on when it deduplicates nodes via set().

    # Concept identifier. Edge.source / Edge.target refer to it, and
    # KnowledgeGraph.draw() uses str(id) as the graphviz node name.
    # NOTE(review): the name shadows the builtin `id`, but it is part of the
    # model's schema, so it is left unchanged.
    id: int= Field(...,

        description="unique id of the concept in the subject domain, used for deduplication, design a scheme allows multiple concept")

    # Human-readable concept text; draw() uses it as the graphviz node label.
    label: str = Field(..., description="description of the concept in the subject domain")

    # Outline colour passed to graphviz by draw(); it has a default, so the
    # LLM does not need to supply it.
    color: str = "orange"
|
|
|
|
|
class Edge(BaseModel, frozen=True):

    """

    Edge representing relationship between concepts in the subject domain, source depends on target

    """

    # NOTE: the docstring and Field descriptions are kept verbatim —
    # presumably they are part of the schema shown to the LLM.
    # frozen=True makes edges hashable so KnowledgeGraph.update() can
    # deduplicate them.

    # Node.id of the dependent concept (per the docstring, source depends on
    # target). draw() converts it with str() to match the graphviz node name.
    source: int = Field(..., description="source representing concept in the subject domain")

    # Node.id of the concept that `source` depends on.
    target: int = Field(..., description="target representing concept in the subject domain")

    # Relationship text; draw() renders it as the edge label.
    label: str = Field(..., description="description representing relationship between concepts in the subject domain")

    # Edge colour passed to graphviz by draw(); defaulted, so optional for the LLM.
    color: str = "black"
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
class KnowledgeGraph(BaseModel):

    """

    KnowledgeGraph is graph representation of concepts in the subject domain

    """

    # The class docstring above is kept verbatim: it is the model description
    # used as the response_model in generate_graph().
    #
    # Both lists are Optional in the schema, so a model response may set them
    # to null explicitly; the methods below treat None like an empty list.
    # default_factory=list alone already gives an empty-list default.
    nodes: Optional[List[Node]] = Field(default_factory=list)

    edges: Optional[List[Edge]] = Field(default_factory=list)

    def update(self, other: "KnowledgeGraph") -> "KnowledgeGraph":
        """Return a new graph merging this graph with *other*, deduplicating nodes and edges.

        Duplicates are detected by value (frozen pydantic models hash and
        compare on their field values). Unlike a ``set`` round-trip, the
        merge keeps first-seen order — this graph's items first, then
        *other*'s — so the result is deterministic from run to run.

        Args:
            other: Graph whose nodes and edges are merged in.

        Returns:
            A new ``KnowledgeGraph``; neither input is modified.
        """
        nodes = [*(self.nodes or []), *(other.nodes or [])]
        edges = [*(self.edges or []), *(other.edges or [])]
        # dict.fromkeys drops duplicates while preserving insertion order.
        return KnowledgeGraph(
            nodes=list(dict.fromkeys(nodes)),
            edges=list(dict.fromkeys(edges)),
        )

    def draw(self, prefix: str = "knowledge_graph", view: bool = True):
        """Render the graph to PNG with graphviz.

        Args:
            prefix: Output filename (without the ``.png`` suffix) passed to
                ``Digraph.render``.
            view: Whether graphviz should open the rendered image
                (defaults to True, the previous behaviour).

        Returns:
            The path of the rendered file, as returned by ``Digraph.render``.
        """
        dot = Digraph(comment="Knowledge Graph")

        for node in self.nodes or []:
            # graphviz node names are strings; edges below reference the same str(id).
            dot.node(str(node.id), node.label, color=node.color)

        for edge in self.edges or []:
            dot.edge(
                str(edge.source), str(edge.target), label=edge.label, color=edge.color
            )

        return dot.render(prefix, format="png", view=view)
|
|
|
|
|
from typing import Iterable |
|
from textwrap import dedent |
|
|
|
|
|
def generate_graph(q, input) -> KnowledgeGraph:
    """Ask the LLM to explain *q* as a knowledge graph, using *input* as context.

    Args:
        q: The question to explain; inserted under "### Question".
        input: Free-text context; inserted under "### Context". (The name
            shadows the builtin but is kept for callers.)

    Returns:
        The ``KnowledgeGraph`` parsed by instructor from the model's
        structured response; ``max_retries=5`` is passed through to it.
    """
    prompt = dedent(f"""Help me understand the following by describing it as a detailed knowledge graph:

### Question: {q}

### Context: {input}

### Instruction:

Generate at least 5 concepts

Generate at least 3 relationship

### Output Format:

Node with id, label for description of the concept

Edge with source's id, target's id, label for description of the relationship between source concept and target concept

""")
    conversation = [{"role": "user", "content": prompt}]
    return client.chat.completions.create(
        model=llm,
        max_retries=5,
        messages=conversation,
        response_model=KnowledgeGraph,
    )
|
|
|
|
|
class Subissue(BaseModel):

    # One element of the Iterable[Subissue] response requested by
    # expandIssue(). Field names/descriptions are kept verbatim because they
    # presumably form the schema shown to the LLM.

    # Title of the sub-issue.
    subissue_title: str

    # Bullet points under the sub-issue; an empty list when the model omits them.
    point: List[str] = Field(default_factory=list, description="Specific aspect or component of the subissue")
|
|
|
|
|
def expandIssue(input) -> Iterable[Subissue]:
    """Decompose the question *input* into MECE sub-issues via the LLM.

    Args:
        input: The question text; inserted under "### Question". (The name
            shadows the builtin but is kept for callers.)

    Returns:
        Whatever instructor returns for ``response_model=Iterable[Subissue]``
        (iterate it to get ``Subissue`` objects). Sampling uses
        ``temperature=0.1``; ``max_retries=3`` is passed through to instructor.
    """
    # NOTE(review): the prompt asks for a bullet-text layout and "nothing
    # else", while response_model requests structured output; presumably the
    # instructor mode overrides this — confirm it does not trigger retries.
    instructions = dedent(f"""

As a McKinsey Consultant, perform MECE decomposition of the question.

### Requirements

1. Return 3 subissues minimum

2. Each sub-issue has 3 bullet points, which each new point beginning with a *

3. Use EXACT format:



- [Sub-issue 1.1 title]

* [point 1]

* [point 2]

* [point 3]

- [Sub-issue 1.2 title]

* [point 1]

* [point 2]

* [point 3]

- [Sub-issue 1.3 title]

* [point 1]

* [point 2]

* [point 3]



4. return nothing else

### Question: {input}

""")
    return client.chat.completions.create(
        model=llm,
        max_retries=3,
        response_model=Iterable[Subissue],
        temperature=0.1,
        messages=[{"role": "user", "content": instructions}],
    )
|
|
|
|
|
def graph(query):
    """Decompose *query* into sub-issues, then build a knowledge graph from them.

    Args:
        query: The question to analyse. It is passed to ``expandIssue`` and
            reused as the question for ``generate_graph``.

    Returns:
        The resulting knowledge graph serialized as a JSON string.
    """
    # Render the sub-issues as a readable "- title / * point" outline (the
    # layout the decomposition prompt asks for). Previously str() of the
    # collection leaked pydantic reprs like "Subissue(subissue_title=...)"
    # into the LLM context.
    lines = []
    for subissue in expandIssue(query):
        lines.append(f"- {subissue.subissue_title}")
        lines.extend(f"  * {point}" for point in subissue.point)
    context = "\n".join(lines)

    # Named so it does not shadow this function.
    knowledge_graph = generate_graph(query, context)
    # model_dump_json() is the pydantic v2 replacement for the deprecated .json().
    return knowledge_graph.model_dump_json()
|
|