|
import os

# Disable OpenTelemetry trace export before any instrumented library is imported.
os.environ["OTEL_TRACES_EXPORTER"] = "none"

import re
import uuid

import gradio as gr

from utils.chat import ChatLaborLaw
|
|
|
|
|
|
|
|
|
|
|
# --- Global initialization (runs once at import time) ----------------------
# NOTE(review): CallbackHandler, ChatOpenAI, ChatGoogleGenerativeAI,
# MongoClient and RerankRetriever are referenced here but not imported in
# this file — confirm the missing import block elsewhere provides them.

# Langfuse tracing callback shared by every LLM invocation below.
LANGFUSE_HANDLER = CallbackHandler()

# Model selection and sampling temperature for the assistant.
MODEL_NAME_LLM = "jai-chat-1-3-2"
TEMPERATURE = 0

if MODEL_NAME_LLM == "jai-chat-1-3-2":
    LLM_MAIN = ChatOpenAI(
        model=MODEL_NAME_LLM,
        api_key=os.getenv("JAI_API_KEY"),
        base_url=os.getenv("CHAT_BASE_URL"),
        temperature=TEMPERATURE,
        max_tokens=2048,
        max_retries=2,
        seed=13,
    )
elif MODEL_NAME_LLM == "gemini-2.0-flash":
    LLM_MAIN = ChatGoogleGenerativeAI(
        # BUG FIX: this branch previously hard-coded model="gemini-1.5-flash",
        # silently instantiating a different model than the one selected by
        # MODEL_NAME_LLM. Use the selected name, mirroring the branch above.
        model=MODEL_NAME_LLM,
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        temperature=TEMPERATURE,
        max_output_tokens=2048,
        convert_system_message_to_human=True,
    )
else:
    raise ValueError(f"Unsupported LLM model '{MODEL_NAME_LLM}'.")

# MongoDB connection used to resolve referenced law sections.
MONGO_CONNECTION_STR = os.getenv('MONGO_CONNECTION_STRING')
MONGO_DATABASE = os.getenv('MONGO_DATABASE')
MONGO_COLLECTION = os.getenv('MONGO_COLLECTION')

MONGO_CLIENT = MongoClient(MONGO_CONNECTION_STR)
DB = MONGO_CLIENT[MONGO_DATABASE]
MONGO_COLLECTION_INSTANCE = DB[MONGO_COLLECTION]

# Reranking retriever backing the RAG context lookup.
RETRIEVER = RerankRetriever()

print("Global objects initialized successfully.")
|
|
|
|
|
|
|
|
|
|
|
def format_main_context(list_of_documents):
    """Render retrieved law documents as one numbered, tab-separated text block.

    Each document contributes a "Doc{i}" section with the law name, section
    number, body text, publication date and effective date taken from its
    metadata; any missing metadata field renders as '-'. Sections are joined
    by blank lines.
    """
    rendered = []
    for idx, document in enumerate(list_of_documents):
        meta = document.metadata
        rendered.append(
            f"Doc{idx}\n"
            f"{meta.get('law_name', '-')}\n"
            f"มาตรา\t{meta.get('section_number', '-')}\n"
            f"{document.page_content}\n"
            f"ประกาศ\t{meta.get('publication_date', '-')}\n"
            f"เริ่มใช้\t{meta.get('effective_date', '-')}"
        )
    return "\n\n".join(rendered)
|
|
|
def format_ref_context(list_of_docs):
    """Render referenced law sections (Mongo documents) as a text block.

    Args:
        list_of_docs: dicts with optional 'law_name', 'section_number' and
            'text' keys; any missing key renders as '-'.

    Returns:
        One string per section formatted as
        "<law_name>\\nมาตรา\\t<section_number>\\n<text>", joined by blank lines.
    """
    # The original loop used enumerate() but never read the index; a
    # generator expression fed to join is the idiomatic equivalent.
    return "\n\n".join(
        f"{doc.get('law_name', '-')}\nมาตรา\t{doc.get('section_number', '-')}\n{doc.get('text', '-')}"
        for doc in list_of_docs
    )
|
|
|
def get_main_context(user_query, **kwargs):
    """Retrieve the main context documents for *user_query*.

    All keyword arguments (e.g. law_type="summary") are forwarded to the
    global reranking retriever when building its compression retriever.
    """
    retriever = RETRIEVER.get_compression_retriever(**kwargs)
    return retriever.invoke(user_query)
|
|
|
def get_ref_context(main_context_docs):
    """Fetch the summary sections referenced by the retrieved documents.

    Collects every entry of each document's metadata 'references' list
    (non-list values are skipped), strips the "มาตรา" prefix, and looks the
    resulting section numbers up in MongoDB. Returns an empty list when no
    references are found.
    """
    section_ids = set()
    for doc in main_context_docs:
        refs = doc.metadata.get('references', [])
        if not isinstance(refs, list):
            continue
        for ref in refs:
            section_ids.add(ref.replace("มาตรา", "").strip())

    if not section_ids:
        return []

    query = {"law_type": "summary", "section_number": {"$in": list(section_ids)}}
    fields = {"text": 1, "law_name": 1, "section_number": 1}
    return list(MONGO_COLLECTION_INSTANCE.find(query, fields))
|
|
|
|
|
|
|
|
|
|
|
async def call_rag(user_input: str, langchain_history: list) -> str:
    """Answer *user_input* using retrieval-augmented generation.

    Retrieves the main context documents, resolves their referenced
    sections, formats both into the RAG prompt along with the chat history,
    and invokes the main LLM. Markup tags and runs of '#' are stripped from
    the model output. On any failure a Thai apology message is returned.
    """
    main_docs = get_main_context(user_input, law_type="summary")
    ref_docs = get_ref_context(main_docs)

    rag_inputs = {
        "question": user_input,
        "main_context": format_main_context(main_docs),
        "ref_context": format_ref_context(ref_docs) if ref_docs else "-",
        "history": langchain_history,
    }

    try:
        messages = RAG_CHAT_PROMPT.format_messages(**rag_inputs)
        result = await LLM_MAIN.ainvoke(messages, config={"callbacks": [LANGFUSE_HANDLER]})
        return re.sub(r"<[^>]+>|#+", "", result.content).strip()
    except Exception as e:
        print(f"Error during RAG LLM call: {e}")
        return "ขออภัย ระบบขัดข้องขณะประมวลผลคำตอบ"
|
|
|
async def call_non_rag(user_input: str) -> str:
    """Answer *user_input* directly with the main LLM (no retrieval).

    Falls back to a Thai apology string when the model returns no content.
    """
    messages = NON_RAG_PROMPT.format(user_input=user_input)
    result = await LLM_MAIN.ainvoke(messages, config={"callbacks": [LANGFUSE_HANDLER]})
    if result and result.content:
        return result.content.strip()
    return "ขออภัย ระบบไม่สามารถตอบคำถามได้ในขณะนี้"
|
|
|
|
|
|
|
|
|
|
|
def initialize_session():
    """Reset all Gradio state for a fresh session.

    Returns (input_text, session_id, ui_history, langchain_history): an
    empty textbox value, a new 8-hex-character session id, and two empty
    history lists.
    """
    # uuid4().hex[:8] equals str(uuid4())[:8] — the first 8 characters of
    # the hyphenated form are the same hex digits.
    return "", uuid.uuid4().hex[:8], [], []
|
|
|
async def chat_orchestrator(prompt: str, ui_history: list, langchain_history: list):
    """Main conversation handler wired to the Gradio submit events.

    Appends the user turn to the LangChain history (mutated in place),
    classifies the input as "RAG" or not (defaulting to Non-RAG if the
    classifier raises), generates the answer, records it in both histories,
    and returns (ui_history, langchain_history, "") — the trailing empty
    string clears the input textbox.
    """
    if not prompt.strip():
        return ui_history, langchain_history, ""

    langchain_history.append(HumanMessage(content=prompt))

    try:
        # Classifier sees plain message texts, including the new user turn.
        past_contents = [msg.content for msg in langchain_history]
        input_type = classify_input_type(prompt, history=past_contents)
    except Exception as e:
        print(f"Error classifying input type: {e}. Defaulting to Non-RAG.")
        input_type = "Non-RAG"

    answer = await (
        call_rag(prompt, langchain_history)
        if input_type == "RAG"
        else call_non_rag(prompt)
    )

    langchain_history.append(AIMessage(content=answer))
    ui_history.append((prompt, answer))
    return ui_history, langchain_history, ""
|
|
|
def send_feedback(feedback: str, history: list, session_id: str):
    """Append user feedback plus the chat transcript to a per-session file.

    Writes to feedback/feedback_<session_id>.txt, shows a thank-you toast,
    and returns "" to clear the feedback textbox. Blank feedback is ignored.
    """
    if not feedback.strip():
        return ""

    os.makedirs("feedback", exist_ok=True)
    path = f"feedback/feedback_{session_id}.txt"
    with open(path, "a", encoding="utf-8") as out:
        out.write(f"=== Feedback Received ===\nSession ID: {session_id}\nFeedback: {feedback}\nChat History:\n")
        for user_turn, assistant_turn in history:
            out.write(f"User: {user_turn}\nAssistant: {assistant_turn}\n")
        out.write("\n--------------------------\n\n")
    gr.Info("ขอบคุณสำหรับข้อเสนอแนะ!")
    return ""
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI: components, event wiring, and app launch -------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
    gr.Markdown("# สอบถามเรื่องกฎหมายแรงงาน")

    # Per-session state: a short session id and the LangChain message history.
    session_id_state = gr.State()
    langchain_history_state = gr.State([])

    chatbot_interface = gr.Chatbot(label="ประวัติการสนทนา", height=550, bubble_styling=False, show_copy_button=True)
    user_input = gr.Textbox(placeholder="พิมพ์คำถามของคุณที่นี่...", label="คำถาม", lines=2)
    with gr.Row():
        submit_button = gr.Button("ส่ง", variant="primary", scale=4)
        clear_button = gr.Button("เริ่มการสนทนาใหม่", scale=1)

    # Both the send button and Enter in the textbox route through the
    # orchestrator; its third output ("") clears the textbox.
    submit_button.click(
        fn=chat_orchestrator,
        inputs=[user_input, chatbot_interface, langchain_history_state],
        outputs=[chatbot_interface, langchain_history_state, user_input]
    )
    user_input.submit(
        fn=chat_orchestrator,
        inputs=[user_input, chatbot_interface, langchain_history_state],
        outputs=[chatbot_interface, langchain_history_state, user_input]
    )
    # "New conversation" resets all state; queue=False runs it immediately.
    clear_button.click(
        fn=initialize_session,
        inputs=[],
        outputs=[user_input, session_id_state, chatbot_interface, langchain_history_state],
        queue=False
    )

    # Feedback panel: saves the text plus the current transcript to disk.
    with gr.Accordion("ส่งข้อเสนอแนะ (Feedback)", open=False):
        feedback_input = gr.Textbox(placeholder="ความคิดเห็นของคุณมีความสำคัญต่อการพัฒนาของเรา...", label="Feedback", lines=2, scale=4)
        send_feedback_button = gr.Button("ส่ง Feedback")

    send_feedback_button.click(
        fn=send_feedback,
        inputs=[feedback_input, chatbot_interface, session_id_state],
        outputs=[feedback_input],
        queue=False
    )

    # Fresh state for every page load.
    demo.load(
        fn=initialize_session,
        inputs=[],
        outputs=[user_input, session_id_state, chatbot_interface, langchain_history_state]
    )

# Enable request queuing and start the web server (blocking call).
demo.queue().launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|