# law_poc / app.py
# Author: SUMANA SUMANAKUL (ING)
# Last commit: "fix requirements" (474fb03)
# NOTE(review): the lines above were Hugging Face blob-page chrome
# ("raw / history / blame / 14.1 kB") accidentally captured with the source;
# converted to comments so the module parses.
import os

# Disable OpenTelemetry trace export before importing libraries that may
# auto-instrument on import.
os.environ["OTEL_TRACES_EXPORTER"] = "none"

import re
import uuid

import gradio as gr

from utils.chat import ChatLaborLaw

# NOTE(review): this module also uses CallbackHandler, ChatOpenAI,
# ChatGoogleGenerativeAI, MongoClient, RerankRetriever, HumanMessage,
# AIMessage, classify_input_type, RAG_CHAT_PROMPT and NON_RAG_PROMPT,
# none of which are imported here — confirm where they come from
# (langfuse / langchain_openai / langchain_google_genai / pymongo / utils).
# ==============================================================================
# 1. GLOBAL INITIALIZATION (runs once at app startup)
# ==============================================================================
# NOTE(review): CallbackHandler, ChatOpenAI, ChatGoogleGenerativeAI,
# MongoClient and RerankRetriever are referenced here but not imported in
# this file — confirm they are imported elsewhere.

# --- Langfuse tracing handler, shared by every LLM call below ---
LANGFUSE_HANDLER = CallbackHandler()

# --- LLM Initialization ---
# Change MODEL_NAME_LLM to select the default chat model.
MODEL_NAME_LLM = "jai-chat-1-3-2"
TEMPERATURE = 0  # deterministic output for legal Q&A

if MODEL_NAME_LLM == "jai-chat-1-3-2":
    LLM_MAIN = ChatOpenAI(
        model=MODEL_NAME_LLM,
        api_key=os.getenv("JAI_API_KEY"),
        base_url=os.getenv("CHAT_BASE_URL"),
        temperature=TEMPERATURE,
        max_tokens=2048,
        max_retries=2,
        seed=13,  # fixed seed for reproducibility
    )
elif MODEL_NAME_LLM == "gemini-2.0-flash":
    LLM_MAIN = ChatGoogleGenerativeAI(
        # Bug fix: this branch previously hard-coded model="gemini-1.5-flash"
        # even though it is selected by "gemini-2.0-flash" — use the selected
        # model name so the branch condition and the instantiated model agree.
        model=MODEL_NAME_LLM,
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        temperature=TEMPERATURE,
        max_output_tokens=2048,
        convert_system_message_to_human=True,
    )
else:
    raise ValueError(f"Unsupported LLM model '{MODEL_NAME_LLM}'.")

# --- Database and Retriever Initialization ---
MONGO_CONNECTION_STR = os.getenv('MONGO_CONNECTION_STRING')
MONGO_DATABASE = os.getenv('MONGO_DATABASE')
MONGO_COLLECTION = os.getenv('MONGO_COLLECTION')
MONGO_CLIENT = MongoClient(MONGO_CONNECTION_STR)
DB = MONGO_CLIENT[MONGO_DATABASE]
MONGO_COLLECTION_INSTANCE = DB[MONGO_COLLECTION]
RETRIEVER = RerankRetriever()
print("Global objects initialized successfully.")
# ==============================================================================
# 2. HELPER FUNCTIONS (แปลงมาจากเมธอดใน Class)
# ==============================================================================
def format_main_context(list_of_documents):
    """Render retrieved law documents into one prompt-ready string.

    Each document becomes a block: DocN marker, law name, section number
    (Thai label), the document body, then publication and effective dates.
    Blocks are separated by a blank line.
    """
    blocks = []
    for idx, document in enumerate(list_of_documents):
        meta = document.metadata
        blocks.append(
            f"Doc{idx}\n"
            f"{meta.get('law_name', '-')}\n"
            f"มาตรา\t{meta.get('section_number', '-')}\n"
            f"{document.page_content}\n"
            f"ประกาศ\t{meta.get('publication_date', '-')}\n"
            f"เริ่มใช้\t{meta.get('effective_date', '-')}"
        )
    return "\n\n".join(blocks)
def format_ref_context(list_of_docs):
    """Render cross-referenced law sections (Mongo documents) into one string.

    Each dict contributes "law_name / มาตรา <section> / text" on separate
    lines; entries are joined by blank lines and missing keys fall back to
    "-".

    Args:
        list_of_docs: iterable of dicts with optional keys 'law_name',
            'section_number' and 'text'.

    Returns:
        str: formatted reference context, "" for an empty input.
    """
    # Fix: the original used enumerate() but never used the index; a plain
    # generator over the docs is the idiomatic form.
    return "\n\n".join(
        f"{doc.get('law_name', '-')}\nมาตรา\t{doc.get('section_number', '-')}\n{doc.get('text', '-')}"
        for doc in list_of_docs
    )
def get_main_context(user_query, **kwargs):
    """Retrieve the primary context documents for *user_query*.

    Keyword arguments (e.g. law_type="summary") are forwarded to
    RETRIEVER.get_compression_retriever.
    """
    return RETRIEVER.get_compression_retriever(**kwargs).invoke(user_query)
def get_ref_context(main_context_docs):
    """Fetch the summary-law sections referenced by the retrieved documents.

    Collects every string in each document's metadata 'references' list
    (entries look like "มาตรา <n>"), strips the "มาตรา" prefix, and looks the
    resulting section numbers up in MongoDB.

    Returns:
        list[dict]: matching Mongo documents (text / law_name /
        section_number fields only), or [] when nothing is referenced.
    """
    reference_ids = set()
    for document in main_context_docs:
        refs = document.metadata.get('references', [])
        # Only iterate genuine lists; any other metadata shape is skipped.
        if isinstance(refs, list):
            reference_ids.update(r.replace("มาตรา", "").strip() for r in refs)

    if not reference_ids:
        return []

    query = {"law_type": "summary", "section_number": {"$in": list(reference_ids)}}
    fields = {"text": 1, "law_name": 1, "section_number": 1}
    return list(MONGO_COLLECTION_INSTANCE.find(query, fields))
# ==============================================================================
# 3. CORE LOGIC FUNCTIONS (RAG / Non-RAG)
# ==============================================================================
async def call_rag(user_input: str, langchain_history: list) -> str:
    """Answer *user_input* using retrieval-augmented generation.

    Gathers the main legal context plus its cross-referenced sections,
    formats the RAG prompt, invokes the LLM (traced via Langfuse), and
    strips HTML-like tags and markdown header marks from the reply.
    Returns a Thai apology message if the LLM call fails.

    NOTE(review): RAG_CHAT_PROMPT is not imported in this file — confirm it
    is defined elsewhere.
    """
    docs = get_main_context(user_input, law_type="summary")
    ref_docs = get_ref_context(docs)
    ref_context_str = format_ref_context(ref_docs) if ref_docs else "-"

    rag_inputs = {
        "question": user_input,
        "main_context": format_main_context(docs),
        "ref_context": ref_context_str,
        "history": langchain_history,
    }
    try:
        messages = RAG_CHAT_PROMPT.format_messages(**rag_inputs)
        reply = await LLM_MAIN.ainvoke(
            messages, config={"callbacks": [LANGFUSE_HANDLER]}
        )
        # Remove HTML-like tags and '#' header runs before returning.
        return re.sub(r"<[^>]+>|#+", "", reply.content).strip()
    except Exception as exc:
        print(f"Error during RAG LLM call: {exc}")
        return "ขออภัย ระบบขัดข้องขณะประมวลผลคำตอบ"
async def call_non_rag(user_input: str) -> str:
    """Answer a general (non-retrieval) question directly with the LLM.

    NOTE(review): NON_RAG_PROMPT is not imported in this file — confirm it
    is defined elsewhere.
    """
    messages = NON_RAG_PROMPT.format(user_input=user_input)
    reply = await LLM_MAIN.ainvoke(messages, config={"callbacks": [LANGFUSE_HANDLER]})
    if reply and reply.content:
        return reply.content.strip()
    # Fallback (Thai): "Sorry, the system cannot answer right now."
    return "ขออภัย ระบบไม่สามารถตอบคำถามได้ในขณะนี้"
# ==============================================================================
# 4. GRADIO EVENT HANDLERS
# ==============================================================================
def initialize_session():
    """Reset all per-session state for a fresh conversation.

    Returns:
        tuple: (cleared user input, new 8-char session id, empty UI history,
        empty LangChain history) — matches the Gradio outputs wiring.
    """
    new_session_id = uuid.uuid4().hex[:8]
    return "", new_session_id, [], []
async def chat_orchestrator(prompt: str, ui_history: list, langchain_history: list):
    """Main conversation handler: classify the input, route to the RAG or
    non-RAG flow, and update both history representations.

    Args:
        prompt: raw user message from the textbox.
        ui_history: Gradio chatbot history as (user, assistant) tuples.
        langchain_history: LangChain message history (mutated in place).

    Returns:
        (ui_history, langchain_history, "") — the trailing empty string
        clears the input textbox.

    NOTE(review): HumanMessage, AIMessage and classify_input_type are not
    imported in this file — confirm they are defined elsewhere.
    """
    # Ignore blank submissions.
    if not prompt.strip():
        return ui_history, langchain_history, ""

    # Record the user turn first so the classifier sees the latest message
    # as part of the conversation history.
    langchain_history.append(HumanMessage(content=prompt))

    # Decide between RAG and non-RAG; on classifier failure fall back to the
    # cheaper non-RAG path rather than crashing the chat.
    try:
        contents = [message.content for message in langchain_history]
        route = classify_input_type(prompt, history=contents)
    except Exception as exc:
        print(f"Error classifying input type: {exc}. Defaulting to Non-RAG.")
        route = "Non-RAG"

    if route == "RAG":
        answer = await call_rag(prompt, langchain_history)
    else:
        answer = await call_non_rag(prompt)

    # Mirror the assistant turn into both history formats.
    langchain_history.append(AIMessage(content=answer))
    ui_history.append((prompt, answer))
    return ui_history, langchain_history, ""
def send_feedback(feedback: str, history: list, session_id: str):
    """Append user feedback plus the chat transcript to a per-session file.

    Blank feedback is ignored. Always returns "" so the Gradio wiring clears
    the feedback textbox.
    """
    if not feedback.strip():
        return ""

    os.makedirs("feedback", exist_ok=True)
    path = f"feedback/feedback_{session_id}.txt"
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(
            f"=== Feedback Received ===\nSession ID: {session_id}\n"
            f"Feedback: {feedback}\nChat History:\n"
        )
        for user_turn, assistant_turn in history:
            fh.write(f"User: {user_turn}\nAssistant: {assistant_turn}\n")
        fh.write("\n--------------------------\n\n")

    # Toast (Thai): "Thank you for the feedback!"
    gr.Info("ขอบคุณสำหรับข้อเสนอแนะ!")
    return ""
# ==============================================================================
# 5. GRADIO UI DEFINITION
# ==============================================================================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
    # App title (Thai): "Ask about labor law".
    gr.Markdown("# สอบถามเรื่องกฎหมายแรงงาน")

    # --- States ---
    # session_id_state: id of the current session (set by initialize_session).
    # langchain_history_state: conversation history as LangChain messages
    # (HumanMessage, AIMessage).
    session_id_state = gr.State()
    langchain_history_state = gr.State([])

    # --- UI Components ---
    chatbot_interface = gr.Chatbot(label="ประวัติการสนทนา", height=550, bubble_styling=False, show_copy_button=True)
    user_input = gr.Textbox(placeholder="พิมพ์คำถามของคุณที่นี่...", label="คำถาม", lines=2)
    with gr.Row():
        submit_button = gr.Button("ส่ง", variant="primary", scale=4)
        clear_button = gr.Button("เริ่มการสนทนาใหม่", scale=1)

    # --- Event Wiring ---
    # Button click and textbox Enter both run the orchestrator, which returns
    # the updated histories plus "" to clear the input box.
    submit_button.click(
        fn=chat_orchestrator,
        inputs=[user_input, chatbot_interface, langchain_history_state],
        outputs=[chatbot_interface, langchain_history_state, user_input]
    )
    user_input.submit(
        fn=chat_orchestrator,
        inputs=[user_input, chatbot_interface, langchain_history_state],
        outputs=[chatbot_interface, langchain_history_state, user_input]
    )
    # Reset all state; queue=False so the reset bypasses the request queue.
    clear_button.click(
        fn=initialize_session,
        inputs=[],
        outputs=[user_input, session_id_state, chatbot_interface, langchain_history_state],
        queue=False
    )

    # Collapsible feedback form; submitting appends to a per-session file.
    with gr.Accordion("ส่งข้อเสนอแนะ (Feedback)", open=False):
        feedback_input = gr.Textbox(placeholder="ความคิดเห็นของคุณมีความสำคัญต่อการพัฒนาของเรา...", label="Feedback", lines=2, scale=4)
        send_feedback_button = gr.Button("ส่ง Feedback")
        send_feedback_button.click(
            fn=send_feedback,
            inputs=[feedback_input, chatbot_interface, session_id_state],
            outputs=[feedback_input],
            queue=False
        )

    # Create a fresh session (new id, cleared histories) on page load.
    demo.load(
        fn=initialize_session,
        inputs=[],
        outputs=[user_input, session_id_state, chatbot_interface, langchain_history_state]
    )

# Enable the request queue (required for the async handlers) and serve.
# NOTE(review): source indentation was mangled in this paste; statement
# nesting above was reconstructed from Gradio's conventional Blocks usage —
# confirm against the original file.
demo.queue().launch()
# ------------------------------------------------------------------------------
# NOTE(review): everything below is a commented-out legacy implementation of
# the session/chat/feedback handlers kept for reference. It is dead text;
# consider deleting it and relying on version control history instead.
# ------------------------------------------------------------------------------
# # Function to initialize a new session and create chatbot instance for that session
# async def initialize_session():
# session_id = str(uuid.uuid4())[:8]
# chatbot = ChatLaborLaw()
# # chatbot = Chat("gemini-2.0-flash")
# history = []
# return "", session_id, chatbot, history
# # Function to handle user input and chatbot response
# async def chat_function(prompt, history, session_id, chatbot):
# if chatbot is None:
# return history, "", session_id, chatbot # Skip if chatbot not ready
# # Append the user's input to the message history
# history.append({"role": "user", "content": prompt})
# # Get the response from the chatbot
# response = await chatbot.chat(prompt) # ใช้ await ได้แล้ว
# # Append the assistant's response to the message history
# history.append({"role": "assistant", "content": response})
# return history, "", session_id, chatbot
# # Function to save feedback with chat history
# async def send_feedback(feedback, history, session_id, chatbot):
# os.makedirs("app/feedback", exist_ok=True)
# filename = f"app/feedback/feedback_{session_id}.txt"
# with open(filename, "a", encoding="utf-8") as f:
# f.write("=== Feedback Received ===\n")
# f.write(f"Session ID: {session_id}\n")
# f.write(f"Feedback: {feedback}\n")
# f.write("Chat History:\n")
# for msg in history:
# f.write(f"{msg['role']}: {msg['content']}\n")
# f.write("\n--------------------------\n\n")
# return "" # Clear feedback input
# # Create the Gradio interface
# with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
# gr.Markdown("# สอบถามเรื่องกฎหมายแรงงาน")
# # Initialize State
# session_state = gr.State()
# chatbot_instance = gr.State()
# chatbot_history = gr.State([])
# # Chat UI
# chatbot_interface = gr.Chatbot(type="messages", label="Chat History")
# user_input = gr.Textbox(placeholder="Type your message here...", elem_id="user_input", lines=1)
# submit_button = gr.Button("Send")
# clear_button = gr.Button("Delete Chat History")
# # Submit actions
# submit_button.click(
# fn=chat_function,
# inputs=[user_input, chatbot_history, session_state, chatbot_instance],
# outputs=[chatbot_interface, user_input, session_state, chatbot_instance]
# )
# user_input.submit(
# fn=chat_function,
# inputs=[user_input, chatbot_history, session_state, chatbot_instance],
# outputs=[chatbot_interface, user_input, session_state, chatbot_instance]
# )
# # # Clear history
# # clear_button.click(lambda: [], outputs=chatbot_interface)
# clear_button.click(
# fn=initialize_session,
# inputs=[],
# outputs=[user_input, session_state, chatbot_instance, chatbot_history]
# ).then(
# fn=lambda: gr.update(value=[]),
# inputs=[],
# outputs=chatbot_interface
# )
# # Feedback section
# with gr.Row():
# feedback_input = gr.Textbox(placeholder="Send us feedback...", label="Feedback")
# send_feedback_button = gr.Button("Send Feedback")
# send_feedback_button.click(
# fn=send_feedback,
# inputs=[feedback_input, chatbot_history, session_state, chatbot_instance],
# outputs=[feedback_input]
# )
# # Initialize session on load
# demo.load(
# fn=initialize_session,
# inputs=[],
# outputs=[user_input, session_state, chatbot_instance, chatbot_history]
# )
# # Launch
# demo.launch(share=True)