# law_poc/utils/chat.py
# Commit 8e5a9dd — SUMANA SUMANAKUL (ING)
# (header copied from the repository's raw / history / blame file view; 10.8 kB)
import os, re
# os.environ["OTEL_TRACES_EXPORTER"] = "none"
os.environ["OTEL_SDK_DISABLED"] = "true"
import uuid
from dotenv import load_dotenv
from utils.chat_prompts import RAG_CHAT_PROMPT, NON_RAG_PROMPT
from utils.reranker import RerankRetriever
from utils.input_classifier import classify_input_type
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from pymongo import MongoClient
from langfuse.langchain import CallbackHandler
from langfuse import observe
load_dotenv()

# ----- MongoDB configuration -----
# Values come from the process environment (populated from .env by
# load_dotenv above); each name is None when its variable is unset.
mongo_username = os.getenv("MONGO_USERNAME")
mongo_password = os.getenv("MONGO_PASSWORD")
mongo_database = os.getenv("MONGO_DATABASE")
mongo_connection_str = os.getenv("MONGO_CONNECTION_STRING")
mongo_collection_name = os.getenv("MONGO_COLLECTION")
class ChatLaborLaw:
    """Labor-law chat assistant that routes each user turn to a RAG or non-RAG flow.

    RAG turns retrieve main context documents through ``RerankRetriever``,
    look up the sections those documents reference in MongoDB, and ask the
    configured LLM to answer using both contexts plus recent chat history.
    LLM calls are traced with a Langfuse ``CallbackHandler``.
    """

    # Patterns used to strip markup from RAG answers before returning them.
    _TAG_RE = re.compile(r"<[^>]+>")  # anything shaped like <...>
    _HASH_RE = re.compile(r"#+")  # runs of '#' (e.g. markdown headings)

    def __init__(self, model_name_llm="jai-chat-1-3-2", temperature=0):
        """Set up tracing, retrieval, MongoDB access and the main LLM.

        Args:
            model_name_llm: ``"jai-chat-1-3-2"`` (OpenAI-compatible endpoint
                configured by ``CHAT_BASE_URL`` / ``JAI_API_KEY``) or
                ``"gemini-2.0-flash"`` (requires ``GOOGLE_API_KEY``).
            temperature: Sampling temperature passed to the LLM.

        Raises:
            ValueError: If ``GOOGLE_API_KEY`` is missing for Gemini, or the
                model name is not supported.
        """
        self.session_id = str(uuid.uuid4())[:8]
        # ----- Langfuse -----
        self.langfuse_handler = CallbackHandler()
        self.history = []  # LangChain message objects, oldest first
        self.model_name_llm = model_name_llm
        self.retriever = RerankRetriever()
        self.client = MongoClient(mongo_connection_str)
        self.db = self.client[mongo_database]
        self.collection = self.db[mongo_collection_name]
        # --- LLM initialization ---
        if model_name_llm == "jai-chat-1-3-2":
            self.llm_main = ChatOpenAI(
                model=model_name_llm,
                api_key=os.getenv("JAI_API_KEY"),
                base_url=os.getenv("CHAT_BASE_URL"),
                temperature=temperature,
                max_tokens=2048,
                max_retries=2,
                seed=13,
            )
        elif model_name_llm == "gemini-2.0-flash":
            gemini_api_key = os.getenv("GOOGLE_API_KEY")
            if not gemini_api_key:
                raise ValueError("GOOGLE_API_KEY (for Gemini) not found in environment variables.")
            self.llm_main = ChatGoogleGenerativeAI(
                model="gemini-2.0-flash",
                google_api_key=gemini_api_key,
                temperature=temperature,
                max_output_tokens=2048,
                convert_system_message_to_human=True,
            )
        else:
            raise ValueError(f"Unsupported LLM model '{model_name_llm}'.")

    def close(self):
        """Close the MongoDB client opened in ``__init__``."""
        self.client.close()

    # ----- Context retrieval -----
    @observe(name='main_context')
    def get_main_context(self, user_query, **kwargs):
        """Retrieve and rerank the main context documents for ``user_query``.

        ``kwargs`` are forwarded to ``RerankRetriever.get_compression_retriever``
        (``call_rag`` passes ``law_type="summary"``).

        Returns:
            Whatever the compression retriever's ``invoke`` returns (consumed
            elsewhere as a list of LangChain ``Document`` objects).
        """
        # NOTE(author): open question — must this be fetched on every turn,
        # and in which cases would the result change?
        # TODO: add a classifier that decides the filters; if the query names
        # a date, compute it and filter on official_version instead.
        compression_retriever = self.retriever.get_compression_retriever(**kwargs)
        main_context_docs = compression_retriever.invoke(user_query)
        return main_context_docs

    @observe(name='ref_context')
    def get_ref_context(self, main_context_docs):
        """Fetch the sections referenced by the main context documents from MongoDB.

        Each document's ``metadata['references']`` is read as a list of strings
        (presumably like ``"มาตรา 118"``); the ``"มาตรา"`` ("section") prefix is
        removed to obtain the section number. All referenced section numbers
        are looked up in one ``$in`` query restricted to ``law_type == "summary"``.
        Non-list ``references`` values and non-string or empty entries are skipped.

        Returns:
            A list of MongoDB documents (projected fields only), one per ``_id``.
        """
        # Collect section numbers across all documents, de-duplicated while
        # keeping first-seen order (dict preserves insertion order).
        ref_numbers = {}
        for context in main_context_docs:
            references_list = context.metadata.get('references', [])
            if not isinstance(references_list, list):
                continue
            for ref_str in references_list:
                if not isinstance(ref_str, str):
                    continue
                number = ref_str.replace("มาตรา", "").strip()
                if number:
                    ref_numbers[number] = None
        if not ref_numbers:
            return []

        mongo_query = {
            "law_type": "summary",
            "section_number": {"$in": list(ref_numbers)},
        }
        # Publication / effective dates and Royal Gazette fields are
        # currently not projected.
        projection = {
            "_id": 1,
            "text": 1,
            "document_type": 1,
            "law_type": 1,
            "law_name": 1,
            "chunk_type": 1,
            "section_number": 1,
        }
        # A single round trip instead of one query per context document;
        # still de-duplicate by _id as a safeguard.
        ref_docs_by_id = {}
        for doc in self.collection.find(mongo_query, projection):
            ref_docs_by_id.setdefault(doc["_id"], doc)
        return list(ref_docs_by_id.values())

    @staticmethod
    def _field_or_dash(value):
        """Return ``value`` as text, or ``"-"`` when it is missing (``None``)."""
        return "-" if value is None else str(value)

    def format_main_context(self, list_of_documents):
        """Render retrieved LangChain documents as prompt text.

        Each document becomes a block of lines: ``Doc{i}``, the law name,
        ``มาตรา`` (section) + section number, the page content, ``ประกาศ``
        (publication date) and ``เริ่มใช้`` (effective date). Publication and
        effective dates are not present on every document; missing or ``None``
        metadata is shown as ``-``. Blocks are separated by a blank line.
        """
        formatted_docs = []
        for i, doc in enumerate(list_of_documents):
            metadata = doc.metadata
            law_name = self._field_or_dash(metadata.get('law_name'))
            section_number = self._field_or_dash(metadata.get('section_number'))
            publication_date = self._field_or_dash(metadata.get('publication_date'))
            effective_date = self._field_or_dash(metadata.get('effective_date'))
            formatted_docs.append("\n".join([
                f"Doc{i}",
                law_name,
                f"มาตรา\t{section_number}",
                doc.page_content,
                f"ประกาศ\t{publication_date}",
                f"เริ่มใช้\t{effective_date}",
            ]))
        return "\n\n".join(formatted_docs)

    def format_ref_context(self, list_of_docs):
        """Render referenced MongoDB section documents as prompt text.

        Each document becomes: law name, ``มาตรา`` (section) + section number,
        and its ``text``; missing or ``None`` fields are shown as ``-``.
        Blocks are separated by a blank line; an empty input gives ``""``.
        """
        formatted_ref_docs = []
        for doc in list_of_docs:
            law_name = self._field_or_dash(doc.get('law_name'))
            section_number = self._field_or_dash(doc.get('section_number'))
            content = self._field_or_dash(doc.get('text'))
            formatted_ref_docs.append("\n".join([
                law_name,
                f"มาตรา\t{section_number}",
                content,
            ]))
        return "\n\n".join(formatted_ref_docs)

    # ----- Chat -----
    # History
    def append_history(self, message: "HumanMessage | AIMessage"):
        """Append one LangChain message to ``self.history``."""
        self.history.append(message)

    def get_formatted_history_for_llm(self, n_turns: int = 3, history=None) -> list:
        """Return the last ``n_turns`` turns (two messages per turn) of history.

        Args:
            n_turns: Number of human/AI turns to keep; ``<= 0`` returns ``[]``
                (a plain ``[-0:]`` slice would return the whole history).
            history: Message sequence to slice; defaults to ``self.history``.
        """
        messages = self.history if history is None else history
        if n_turns <= 0:
            return []
        return list(messages[-(n_turns * 2):])

    # Classify
    @observe(name='classify_input_type')
    def classify_input(self, user_input: str) -> str:
        """Classify ``user_input`` with ``classify_input_type``.

        The classifier receives the text content of every message currently in
        ``self.history``; ``chat`` compares the returned label with ``"RAG"``.
        """
        history_content_list = [msg.content for msg in self.history]  # text only
        return classify_input_type(user_input, history=history_content_list)

    # Chat
    @observe(name="chat_flow")
    async def chat(self, user_input: str) -> str:
        """Answer one user turn and record it in ``self.history``.

        The input is classified (errors fall back to ``"Non-RAG"``) and sent to
        ``call_rag`` for ``"RAG"`` or ``call_non_rag`` otherwise. On success the
        Human/AI message pair stays in the history. If the flow raises, the
        pending HumanMessage is removed again and the exception re-raised, so
        the history keeps alternating human/AI messages.
        """
        history_before_current_input = self.history[:]
        self.append_history(HumanMessage(content=user_input))
        try:
            input_type = self.classify_input(user_input)
        except Exception as e:
            print(f"Error classifying input type: {e}. Defaulting to Non-RAG.")
            input_type = "Non-RAG"
        try:
            if input_type == "RAG":
                # Pass the pre-turn history: the current question is already
                # supplied to the prompt as `question`.
                ai_response_content = await self.call_rag(
                    user_input, history=history_before_current_input
                )
            else:
                # Every non-"RAG" label is handled by the non-RAG flow.
                ai_response_content = await self.call_non_rag(user_input)
        except BaseException:
            del self.history[len(history_before_current_input):]
            raise
        self.append_history(AIMessage(content=ai_response_content))
        return ai_response_content

    @observe(name='rag_flow')
    async def call_rag(self, user_input: str, history=None) -> str:
        """Answer ``user_input`` with retrieval-augmented generation.

        Args:
            user_input: The user's question.
            history: Prior messages to show the LLM (the last 3 turns are
                used). Defaults to ``self.history``; ``chat`` passes the
                history from before the current question.

        Returns:
            The LLM answer with ``<...>`` tags and ``#`` characters removed and
            surrounding whitespace stripped, or a fixed apology string if
            prompt formatting or the LLM call fails. Retrieval errors propagate.
        """
        # Main context: summary-type law documents.
        context_docs = self.get_main_context(user_input, law_type="summary")  # chunk_type='section'
        main_context_str = self.format_main_context(context_docs)
        # Reference context: sections cited by the main context documents.
        ref_context_docs = self.get_ref_context(context_docs)
        try:
            ref_context_str = self.format_ref_context(ref_context_docs)
        except Exception as e:
            # Best effort: a formatting problem should not fail the whole turn.
            print(f"Error formatting reference context: {e}")
            ref_context_str = "-"
        rag_input_data = {
            "question": user_input,
            "main_context": main_context_str,
            "ref_context": ref_context_str,
            "history": self.get_formatted_history_for_llm(n_turns=3, history=history),
        }
        try:
            prompt_messages = RAG_CHAT_PROMPT.format_messages(**rag_input_data)
            response = await self.llm_main.ainvoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            clean_response = self._TAG_RE.sub("", response.content)
            clean_response = self._HASH_RE.sub("", clean_response)
            return clean_response.strip()
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    @observe(name='non_rag_flow')
    async def call_non_rag(self, user_input: str) -> str:
        """Answer ``user_input`` with ``NON_RAG_PROMPT`` (no retrieval, no history).

        Returns the stripped LLM answer, or a fixed Thai apology ("sorry, the
        system cannot answer right now") when the response or its content is
        empty.
        """
        prompt_messages = NON_RAG_PROMPT.format(user_input=user_input)
        response = await self.llm_main.ainvoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
        # Guard against a missing response or None/empty content.
        if not response or not response.content:
            return "ขออภัย ระบบไม่สามารถตอบคำถามได้ในขณะนี้"
        return response.content.strip()