|
import gradio as gr |
|
from gradio import ChatMessage |
|
from langgraph.graph import StateGraph, END |
|
from typing import TypedDict, Annotated |
|
import operator |
|
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage |
|
from langchain_openai import ChatOpenAI |
|
import asyncio |
|
import time |
|
import os |
|
from pprint import pprint |
|
|
|
|
|
class ChatState(TypedDict): |
|
messages: Annotated[list[BaseMessage], operator.add] |
|
current_response: str |
|
agent_type: str |
|
context: dict |
|
step_info: str |
|
|
|
|
|
llm = ChatOpenAI( |
|
model="gpt-3.5-turbo", |
|
temperature=0.7, |
|
) |
|
|
|
def step1_analyzer_node(state: ChatState) -> ChatState: |
|
"""1๋จ๊ณ: ๋ฉ์์ง ๋ถ์ ๋
ธ๋""" |
|
last_message = state["messages"][-1].content |
|
|
|
|
|
analysis = { |
|
"length": len(last_message), |
|
"has_question": "?" in last_message, |
|
"language": "korean" if any(ord(c) > 127 for c in last_message) else "english", |
|
"sentiment": "positive" if any(word in last_message.lower() for word in ["์ข", "๊ฐ์ฌ", "๊ณ ๋ง์", "thanks", "good"]) else "neutral" |
|
} |
|
|
|
return { |
|
"messages": [], |
|
"current_response": "", |
|
"agent_type": "", |
|
"context": {"analysis": analysis}, |
|
"step_info": f"๐ **1๋จ๊ณ ์๋ฃ** - ๋ฉ์์ง ๋ถ์\n- ๊ธธ์ด: {analysis['length']}์\n- ์ง๋ฌธ ํฌํจ: {'์' if analysis['has_question'] else '์๋์ค'}\n- ์ธ์ด: {analysis['language']}\n- ๊ฐ์ : {analysis['sentiment']}" |
|
} |
|
|
|
def step2_classifier_node(state: ChatState) -> ChatState: |
|
"""2๋จ๊ณ: ์๋ ๋ถ๋ฅ ๋
ธ๋""" |
|
last_message = state["messages"][-1].content.lower() |
|
analysis = state["context"]["analysis"] |
|
|
|
|
|
if any(word in last_message for word in ["์ฝ๋", "ํ๋ก๊ทธ๋๋ฐ", "python", "๊ฐ๋ฐ", "ํจ์", "ํด๋์ค"]): |
|
agent_type = "programmer" |
|
confidence = 0.9 |
|
elif any(word in last_message for word in ["๋ ์จ", "๋ด์ค", "์ ๋ณด", "๊ฒ์", "์ฐพ์"]): |
|
agent_type = "informer" |
|
confidence = 0.8 |
|
elif any(word in last_message for word in ["๊ณ์ฐ", "์ํ", "๋ํ๊ธฐ", "๋นผ๊ธฐ", "๊ณฑํ๊ธฐ", "๋๋๊ธฐ"]): |
|
agent_type = "calculator" |
|
confidence = 0.95 |
|
elif any(word in last_message for word in ["์ฐฝ์", "์", "์์ค", "์ด์ผ๊ธฐ", "๊ธ"]): |
|
agent_type = "creative" |
|
confidence = 0.85 |
|
else: |
|
agent_type = "general" |
|
confidence = 0.7 |
|
|
|
context = state["context"] |
|
context["classification"] = { |
|
"agent_type": agent_type, |
|
"confidence": confidence |
|
} |
|
|
|
return { |
|
"messages": [], |
|
"current_response": "", |
|
"agent_type": agent_type, |
|
"context": context, |
|
"step_info": f"๐ฏ **2๋จ๊ณ ์๋ฃ** - ์๋ ๋ถ๋ฅ\n- ๋ถ๋ฅ ๊ฒฐ๊ณผ: {agent_type}\n- ์ ๋ขฐ๋: {confidence:.1%}\n- ๋ค์ ๋จ๊ณ: {'์ ๋ฌธ ์ฒ๋ฆฌ' if confidence > 0.8 else '์ผ๋ฐ ์ฒ๋ฆฌ'}" |
|
} |
|
|
|
def step3_context_enricher_node(state: ChatState) -> ChatState: |
|
"""3๋จ๊ณ: ์ปจํ
์คํธ ๊ฐํ ๋
ธ๋""" |
|
agent_type = state["agent_type"] |
|
|
|
|
|
enriched_context = { |
|
"programmer": { |
|
"system_prompt": "๋น์ ์ ๊ฒฝํ์ด ํ๋ถํ ์๋์ด ๊ฐ๋ฐ์์
๋๋ค. ์ฝ๋ ์์์ ํจ๊ป ๋ช
ํํ๊ณ ์ค์ฉ์ ์ธ ๋ต๋ณ์ ์ ๊ณตํ์ธ์.", |
|
"tools": ["์ฝ๋ ์คํ", "๋ฌธ์ ๊ฒ์", "๋ฒ ์คํธ ํ๋ํฐ์ค"], |
|
"style": "๊ธฐ์ ์ ์ด๊ณ ์ ํํ" |
|
}, |
|
"informer": { |
|
"system_prompt": "๋น์ ์ ์ ๋ณด ์ ๋ฌธ๊ฐ์
๋๋ค. ์ ํํ๊ณ ์ต์ ์ ์ ๋ณด๋ฅผ ๊ตฌ์กฐํ๋ ํํ๋ก ์ ๊ณตํ์ธ์.", |
|
"tools": ["์น ๊ฒ์", "ํฉํธ ์ฒดํฌ", "๋ฐ์ดํฐ ๋ถ์"], |
|
"style": "๊ฐ๊ด์ ์ด๊ณ ์์ธํ" |
|
}, |
|
"calculator": { |
|
"system_prompt": "๋น์ ์ ์ํ ์ ๋ฌธ๊ฐ์
๋๋ค. ๊ณ์ฐ ๊ณผ์ ์ ๋จ๊ณ๋ณ๋ก ์ค๋ช
ํ๊ณ ์ ํํ ๋ต์ ์ ๊ณตํ์ธ์.", |
|
"tools": ["์์ ๊ณ์ฐ", "๊ทธ๋ํ ์์ฑ", "ํต๊ณ ๋ถ์"], |
|
"style": "๋
ผ๋ฆฌ์ ์ด๊ณ ์ฒด๊ณ์ ์ธ" |
|
}, |
|
"creative": { |
|
"system_prompt": "๋น์ ์ ์ฐฝ์ ์ ๋ฌธ๊ฐ์
๋๋ค. ์์๋ ฅ์ด ํ๋ถํ๊ณ ๊ฐ์ฑ์ ์ธ ์ฝํ
์ธ ๋ฅผ ์ ์ํ์ธ์.", |
|
"tools": ["์คํ ๋ฆฌํ
๋ง", "์๊ฐ์ ๋ฌ์ฌ", "๊ฐ์ ํํ"], |
|
"style": "์ฐฝ์์ ์ด๊ณ ๊ฐ์ฑ์ ์ธ" |
|
}, |
|
"general": { |
|
"system_prompt": "๋น์ ์ ์น๊ทผํ๊ณ ๋์์ด ๋๋ AI ์ด์์คํดํธ์
๋๋ค. ์์ฐ์ค๋ฝ๊ณ ์ดํดํ๊ธฐ ์ฌ์ด ๋ต๋ณ์ ์ ๊ณตํ์ธ์.", |
|
"tools": ["์ผ๋ฐ ๋ํ", "์ ๋ณด ์ ๊ณต", "๋ฌธ์ ํด๊ฒฐ"], |
|
"style": "์น๊ทผํ๊ณ ์์ฐ์ค๋ฌ์ด" |
|
} |
|
} |
|
|
|
context = state["context"] |
|
context["enriched"] = enriched_context.get(agent_type, enriched_context["general"]) |
|
|
|
return { |
|
"messages": [], |
|
"current_response": "", |
|
"agent_type": agent_type, |
|
"context": context, |
|
"step_info": f"๐ง **3๋จ๊ณ ์๋ฃ** - ์ปจํ
์คํธ ๊ฐํ\n- ์์ด์ ํธ: {agent_type}\n- ์คํ์ผ: {context['enriched']['style']}\n- ํ์ฉ ๋๊ตฌ: {', '.join(context['enriched']['tools'][:2])}" |
|
} |
|
|
|
def step4_response_generator_node(state: ChatState) -> ChatState: |
|
"""4๋จ๊ณ: ์๋ต ์์ฑ ๋
ธ๋""" |
|
enriched_context = state["context"]["enriched"] |
|
system_prompt = enriched_context["system_prompt"] |
|
|
|
|
|
messages = [HumanMessage(content=system_prompt)] + state["messages"] |
|
|
|
try: |
|
response = llm.invoke(messages) |
|
|
|
|
|
icons = { |
|
"programmer": "๐ป", |
|
"informer": "๐", |
|
"calculator": "๐ข", |
|
"creative": "๐จ", |
|
"general": "๐ฌ" |
|
} |
|
|
|
icon = icons.get(state["agent_type"], "๐ฌ") |
|
final_response = f"{icon} **[{state['agent_type'].upper()}]**\n\n{response.content}" |
|
|
|
return { |
|
"messages": [response], |
|
"current_response": final_response, |
|
"agent_type": state["agent_type"], |
|
"context": state["context"], |
|
"step_info": f"โ
**4๋จ๊ณ ์๋ฃ** - ์๋ต ์์ฑ\n- ์ต์ข
์๋ต ์ค๋น๋จ\n- ์๋ต ๊ธธ์ด: {len(response.content)}์" |
|
} |
|
|
|
except Exception as e: |
|
error_msg = f"โ ์๋ต ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}" |
|
return { |
|
"messages": [AIMessage(content=error_msg)], |
|
"current_response": error_msg, |
|
"agent_type": state["agent_type"], |
|
"context": state["context"], |
|
"step_info": f"โ **4๋จ๊ณ ์คํจ** - ์ค๋ฅ ๋ฐ์\n- ์ค๋ฅ: {str(e)}" |
|
} |
|
|
|
def should_continue_to_classifier(state: ChatState) -> str: |
|
return "classifier" |
|
|
|
def should_continue_to_enricher(state: ChatState) -> str: |
|
return "enricher" |
|
|
|
def should_continue_to_generator(state: ChatState) -> str: |
|
return "generator" |
|
|
|
def should_end(state: ChatState) -> str: |
|
return END |
|
|
|
|
|
def create_enhanced_workflow(): |
|
workflow = StateGraph(ChatState) |
|
|
|
|
|
workflow.add_node("analyzer", step1_analyzer_node) |
|
workflow.add_node("classifier", step2_classifier_node) |
|
workflow.add_node("enricher", step3_context_enricher_node) |
|
workflow.add_node("generator", step4_response_generator_node) |
|
|
|
|
|
workflow.set_entry_point("analyzer") |
|
|
|
|
|
workflow.add_conditional_edges( |
|
"analyzer", |
|
should_continue_to_classifier, |
|
{"classifier": "classifier"} |
|
) |
|
|
|
workflow.add_conditional_edges( |
|
"classifier", |
|
should_continue_to_enricher, |
|
{"enricher": "enricher"} |
|
) |
|
|
|
workflow.add_conditional_edges( |
|
"enricher", |
|
should_continue_to_generator, |
|
{"generator": "generator"} |
|
) |
|
|
|
workflow.add_conditional_edges( |
|
"generator", |
|
should_end, |
|
{END: END} |
|
) |
|
|
|
return workflow.compile() |
|
|
|
|
|
enhanced_workflow = create_enhanced_workflow() |
|
|
|
def stream_chatbot_response(message, history): |
|
"""๊ฐ ๋จ๊ณ๋ฅผ ๋์ ํด์ ์ค์๊ฐ ํ์""" |
|
if not message.strip(): |
|
yield "", history |
|
return |
|
|
|
|
|
messages = [] |
|
for human_msg, ai_msg in history: |
|
if human_msg and not "๐" in human_msg and not "๐ฏ" in human_msg: |
|
messages.append(HumanMessage(content=human_msg)) |
|
if ai_msg and not "๐" in ai_msg and not "๐ฏ" in ai_msg and not "โก" in ai_msg: |
|
messages.append(AIMessage(content=ai_msg)) |
|
|
|
|
|
messages.append(HumanMessage(content=message)) |
|
|
|
|
|
initial_state = { |
|
"messages": messages, |
|
"current_response": "", |
|
"agent_type": "", |
|
"context": {}, |
|
"step_info": "" |
|
} |
|
|
|
|
|
current_history = history.copy() |
|
yield "", current_history |
|
time.sleep(0.3) |
|
|
|
|
|
for step_result in enhanced_workflow.stream(initial_state): |
|
node_name = list(step_result.keys())[0] |
|
node_result = step_result[node_name] |
|
pprint(step_result) |
|
if "step_info" in node_result and node_result["step_info"]: |
|
|
|
step_info = node_result["step_info"] |
|
pprint(step_info) |
|
|
|
current_history.append(ChatMessage( role="assistant", |
|
content=step_info, |
|
metadata={"title": f"{node_name}", "status": "done"})) |
|
yield "", current_history |
|
time.sleep(0.2) |
|
|
|
current_history.append(ChatMessage( role="assistant", |
|
content=step_result["generator"]["current_response"])) |
|
|
|
yield "", current_history |
|
|
|
|
|
def clear_chat(): |
|
"""์ฑํ
ํ์คํ ๋ฆฌ ์ด๊ธฐํ""" |
|
return [] |
|
|
|
|
|
def create_enhanced_gradio_interface(): |
|
with gr.Blocks(title="Enhanced LangGraph ์ฑ๋ด", theme=gr.themes.Soft()) as demo: |
|
gr.Markdown( |
|
""" |
|
# ๐ Enhanced LangGraph + Gradio ์ฑ๋ด |
|
|
|
**4๋จ๊ณ ์ฒ๋ฆฌ ๊ณผ์ ์ ์ค์๊ฐ์ผ๋ก ํ์ธํ ์ ์๋ AI ์ฑ๋ด** |
|
|
|
**์ฒ๋ฆฌ ๋จ๊ณ:** |
|
1. ๐ **๋ฉ์์ง ๋ถ์** - ์
๋ ฅ ๋ฉ์์ง์ ํน์ฑ ๋ถ์ |
|
2. ๐ฏ **์๋ ๋ถ๋ฅ** - ์ฌ์ฉ์ ์๋์ ๋ฐ๋ฅธ ์์ด์ ํธ ์ ํ |
|
3. ๐ง **์ปจํ
์คํธ ๊ฐํ** - ์ ๋ฌธ ๋๋ฉ์ธ๋ณ ์ปจํ
์คํธ ์ค์ |
|
4. โ
**์๋ต ์์ฑ** - ์ต์ ํ๋ ๋ต๋ณ ์์ฑ |
|
|
|
**์ง์ ์์ด์ ํธ:** ๐ป ํ๋ก๊ทธ๋๋จธ | ๐ ์ ๋ณด์ ๋ฌธ๊ฐ | ๐ข ๊ณ์ฐ๊ธฐ | ๐จ ์ฐฝ์๊ฐ | ๐ฌ ์ผ๋ฐ๋ํ |
|
""" |
|
) |
|
|
|
|
|
chatbot = gr.Chatbot( |
|
value=[], |
|
height=400, |
|
show_label=False, |
|
container=True, |
|
type="messages" |
|
) |
|
|
|
with gr.Row(): |
|
msg = gr.Textbox( |
|
placeholder="๋ฉ์์ง๋ฅผ ์
๋ ฅํ์ธ์... (๊ฐ ์ฒ๋ฆฌ ๋จ๊ณ๊ฐ ์ค์๊ฐ์ผ๋ก ํ์๋ฉ๋๋ค)", |
|
show_label=False, |
|
scale=4, |
|
container=False |
|
) |
|
submit_btn = gr.Button("์ ์ก", scale=1, variant="primary") |
|
clear_btn = gr.Button("์ด๊ธฐํ", scale=1, variant="secondary") |
|
|
|
|
|
with gr.Row(): |
|
gr.Markdown("๐ก **ํ**: ๋ค์ํ ์ฃผ์ ๋ก ๋ํํด๋ณด์ธ์. ๊ฐ ๋จ๊ณ๋ณ ์ฒ๋ฆฌ ๊ณผ์ ์ ํ์ธํ ์ ์์ต๋๋ค!") |
|
|
|
|
|
submit_btn.click( |
|
stream_chatbot_response, |
|
inputs=[msg, chatbot], |
|
outputs=[msg, chatbot] |
|
) |
|
|
|
msg.submit( |
|
stream_chatbot_response, |
|
inputs=[msg, chatbot], |
|
outputs=[msg, chatbot] |
|
) |
|
|
|
|
|
clear_btn.click( |
|
clear_chat, |
|
outputs=[chatbot] |
|
) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("### ๐ป ํ๋ก๊ทธ๋๋ฐ") |
|
gr.Examples( |
|
examples=[ |
|
"Python์ผ๋ก ํผ๋ณด๋์น ์์ด ํจ์ ๋ง๋๋ ๋ฐฉ๋ฒ?", |
|
"๋์
๋๋ฆฌ์ ๋ฆฌ์คํธ์ ์ฐจ์ด์ ์ ์๋ ค์ค", |
|
"ํด๋์ค์ ๊ฐ์ฒด์ ๋ํด ์ค๋ช
ํด์ค" |
|
], |
|
inputs=msg |
|
) |
|
|
|
with gr.Column(): |
|
gr.Markdown("### ๐ข ๊ณ์ฐ/์ํ") |
|
gr.Examples( |
|
examples=[ |
|
"25 ๊ณฑํ๊ธฐ 37์ ์ผ๋ง์ผ?", |
|
"๋ณต๋ฆฌ ๊ณ์ฐ ๋ฐฉ๋ฒ์ ์๋ ค์ค", |
|
"์ผ๊ฐํจ์์ ๋ํด ์ค๋ช
ํด์ค" |
|
], |
|
inputs=msg |
|
) |
|
|
|
with gr.Column(): |
|
gr.Markdown("### ๐จ ์ฐฝ์") |
|
gr.Examples( |
|
examples=[ |
|
"๋ด์ ๋ํ ์งง์ ์๋ฅผ ์จ์ค", |
|
"์ฐ์ฃผ ์ฌํ ์ด์ผ๊ธฐ๋ฅผ ๋ง๋ค์ด์ค", |
|
"์ฐฝ์์ ์ธ ์์ด๋์ด๋ฅผ ์ ์ํด์ค" |
|
], |
|
inputs=msg |
|
) |
|
|
|
return demo |
|
|
|
demo = create_enhanced_gradio_interface() |
|
if __name__ == "__main__": |
|
demo.launch() |