import json
import asyncio
from typing import List

from typing_extensions import TypedDict

from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langgraph.graph import StateGraph, END

from src.utils.api_key_manager import with_api_manager
from src.helpers.helper import remove_markdown


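# Shared state passed between the graph's nodes.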
class GraphState(TypedDict):
    initial_prompt: str
    plan: str
    write_steps: List[dict]
    final_json: str


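# Planning node: asks the LLM for a plain-text outline of the JSON to build,
# one step (or sub-step) per line, without writing any content yet.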
@with_api_manager(temperature=0.0, top_p=1.0)
def planning_node(state: GraphState, *, llm) -> GraphState:
    print("\n---PLANNING---\n")

    initial_prompt = state['initial_prompt']

    plan_template = \
f"""You need to create a structured JSON based on the following instructions:
{initial_prompt}

Rules:
1. Outline a multi-step plan (one step per line) that will guide the creation of the final JSON.
2. You must create the entire plan yourself without asking others to create it for you.
3. The steps should be as follows:
    - Each step should be a high-level task or section of the JSON.
    - Check if breaking down each step into smaller, low-level sub-tasks or sections is required.
    - If yes, ONLY include the sub-steps (one sub-step per line).
4. The plan should be concise and clear, and each step and sub-step should be distinct.
5. The plan should be unformatted plain text. DO NOT use bullet points, numbering, or markdown.
6. The number of steps should be as few as possible, but still enough to cover ALL sections.
7. If the user request contains any specific details, include them in the plan.
8. DO NOT create the final content, just the plan/outline.
9. DO NOT include any markdown or formatting in the plan."""

    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    prompt = chat_template.invoke({"text": plan_template})

    response = llm.invoke(prompt)
    plan = response.content.strip()

    state['plan'] = remove_markdown(plan)
    print(plan)
    return state


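# Synchronous writing node: walks the plan line by line and asks the LLM to
# produce a valid JSON fragment for each step, parsing every response.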
@with_api_manager(temperature=0.0, top_p=1.0)
def writing_node_sync(state: GraphState, *, llm) -> GraphState:
    print("\n---WRITING THE JSON---\n")

    initial_prompt = state['initial_prompt']
    plan = state['plan'].strip()

    plan_lines = plan.split('\n')

    partial_jsons: List[dict] = []

    for idx, step_line in enumerate(plan_lines):
        if len(step_line.strip()) > 0:
            step_prompt_text = \
f"""You are creating part {idx+1} of the final JSON document.
User request:
{initial_prompt}

Plan step (outline):
{step_line.strip()}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

            chat_template = ChatPromptTemplate.from_messages([
                HumanMessagePromptTemplate.from_template("{text}"),
            ])
            prompt = chat_template.invoke({"text": step_prompt_text})

            response = llm.invoke(prompt)
            step_result = response.content.strip()

            try:
                cleaned_result = remove_markdown(step_result)
                partial_obj = json.loads(cleaned_result)
            except json.JSONDecodeError as e:
                raise Exception(f"Failed to parse JSON data for step {idx+1}: {e}") from e

            partial_jsons.append(partial_obj)

    state['write_steps'] = partial_jsons
    return state


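# Asynchronous writing node: same per-step prompt as the sync version, but all
# steps are expanded concurrently via asyncio.gather. The nested helper closes
# over initial_prompt, which is bound below before any task is scheduled.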
@with_api_manager(temperature=0.0, top_p=1.0)
async def writing_node_async(state: GraphState, *, llm) -> GraphState:
    async def get_partial_json(idx: int, step_line: str) -> dict:
        step_prompt_text = \
f"""You are creating part {idx+1} of the final JSON document.
User request:
{initial_prompt}

Plan step (outline):
{step_line.strip()}

Rules:
1. You need to write the JSON data for this step.
2. The JSON should be structured and valid.
3. If the user request contains any specific details, include them in the JSON.
4. If the user request contains the format of the JSON, follow it. If not, create a generic JSON as you see fit.
5. Respond ONLY with valid JSON for this step without any markdown or formatting."""

        chat_template = ChatPromptTemplate.from_messages([
            HumanMessagePromptTemplate.from_template("{text}"),
        ])
        prompt = chat_template.invoke({"text": step_prompt_text})

        response = await llm.ainvoke(prompt)
        step_result = response.content.strip()

        cleaned_result = remove_markdown(step_result)
        try:
            partial_obj = json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise Exception(f"Failed to parse JSON data for step {idx+1}: {e}") from e

        return partial_obj

    print("\n---WRITING THE JSON---\n")

    initial_prompt = state['initial_prompt']
    plan = state['plan'].strip()

    plan_lines = plan.split('\n')

    tasks = []
    for idx, line in enumerate(plan_lines):
        if len(line.strip()) > 0:
            tasks.append(asyncio.create_task(get_partial_json(idx, line)))

    partial_jsons = await asyncio.gather(*tasks)

    state['write_steps'] = list(partial_jsons)
    return state


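# Consolidation node: wraps the plan and the per-step fragments into a single
# JSON document and stores it on the state.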
def consolidation_node(state: GraphState) -> GraphState:
    print("\n---CONSOLIDATING THE JSON---\n")

    plan = state['plan']
    partial_jsons = state['write_steps']

    final_obj = {
        "plan": plan,
        "steps": partial_jsons
    }

    final_json_str = json.dumps(final_obj, ensure_ascii=False, indent=2)

    state['final_json'] = final_json_str
    return state


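# Sequential pipeline: planning -> writing (one step at a time) -> consolidation.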
def create_workflow_sync() -> StateGraph:
    workflow = StateGraph(GraphState)

    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_sync)
    workflow.add_node("consolidation_node", consolidation_node)

    workflow.set_entry_point("planning_node")

    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()


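# Concurrent pipeline: same graph shape, but with the async writing node.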
def create_workflow_async() -> StateGraph:
    workflow = StateGraph(GraphState)

    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node_async)
    workflow.add_node("consolidation_node", consolidation_node)

    workflow.set_entry_point("planning_node")

    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()


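# Manual smoke test: run the concurrent workflow end to end and time it.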
if __name__ == "__main__": |
|
import time |
|
|
|
test_instruction = "Write a 1500-word piece on the HBO TV show Westworld, covering major characters, \ |
|
themes of AI and consciousness, and how the story might have continued had it not been cancelled. \ |
|
Include specific details, quotes, and references to the show and its creators.\ |
|
Do not include any spoilers for the climax of the show's final season." |
|
|
|
app = create_workflow_async() |
|
|
|
|
|
state_input: GraphState = { |
|
"initial_prompt": test_instruction, |
|
"plan": "", |
|
"write_steps": [], |
|
"final_json": "" |
|
} |
|
start = time.time() |
|
final_state = asyncio.run(app.ainvoke(state_input)) |
|
end = time.time() |
|
|
|
|
|
print("\n===== FINAL JSON OUTPUT =====\n") |
|
print(final_state['final_json']) |
|
print("=============================\n") |
|
|
|
print("\n===== PERFOMANCE =====\n") |
|
print(f"Time taken: {end-start:.2f} seconds") |
|
print("======================\n") |
|
|