import os
import re

# OpenTelemetry exporter toggles (uncomment to disable tracing entirely):
# os.environ["OTEL_TRACES_EXPORTER"] = "none"
# os.environ["OTEL_SDK_DISABLED"] = "true"
os.environ["OTEL_TRACES_EXPORTER"] = "console"

import uuid
from typing import Union

from dotenv import load_dotenv
from utils.chat_prompts import RAG_CHAT_PROMPT, NON_RAG_PROMPT
from utils.reranker import RerankRetriever
from utils.input_classifier import classify_input_type
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from pymongo import MongoClient
from langfuse.langchain import CallbackHandler
from langfuse import observe

load_dotenv()

# MongoDB configuration
mongo_username = os.environ.get('MONGO_USERNAME')
mongo_password = os.environ.get('MONGO_PASSWORD')
mongo_database = os.environ.get('MONGO_DATABASE')
mongo_connection_str = os.environ.get('MONGO_CONNECTION_STRING')
mongo_collection_name = os.environ.get('MONGO_COLLECTION')


class ChatLaborLaw:
    def __init__(self, model_name_llm="jai-chat-1-3-2", temperature=0):
        self.session_id = str(uuid.uuid4())[:8]

        # ----- Langfuse -----
        self.langfuse_handler = CallbackHandler()

        self.history = []  # Store LangChain Message objects
        self.model_name_llm = model_name_llm
        self.retriever = RerankRetriever()
        self.client = MongoClient(mongo_connection_str)
        self.db = self.client[mongo_database]
        self.collection = self.db[mongo_collection_name]

        # --- LLM Initialization ---
        if model_name_llm == "jai-chat-1-3-2":
            self.llm_main = ChatOpenAI(
                model=model_name_llm,
                api_key=os.getenv("JAI_API_KEY"),
                base_url=os.getenv("CHAT_BASE_URL"),
                temperature=temperature,
                max_tokens=2048,
                max_retries=2,
                seed=13
            )
        elif model_name_llm == "gemini-2.0-flash":
            GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
            if not GEMINI_API_KEY:
                raise ValueError("GOOGLE_API_KEY (for Gemini) not found in environment variables.")
            common_gemini_config = {
                "google_api_key": GEMINI_API_KEY,
                "temperature": temperature,
                "max_output_tokens": 2048,
                "convert_system_message_to_human": True,
            }
            self.llm_main = ChatGoogleGenerativeAI(model="gemini-2.0-flash", **common_gemini_config)
        else:
            raise ValueError(f"Unsupported LLM model '{model_name_llm}'.")

    # ----- Context Retrieval -----
    @observe(name='main_context')
    def get_main_context(self, user_query, **kwargs):
        # NOTE: do we need to rebuild the retriever on every call? Which cases change the filters?
        # NOTE: a classifier should decide the filters -- if the query specifies a time period,
        #       compute it and pull from official_version instead.
        compression_retriever = self.retriever.get_compression_retriever(**kwargs)
        main_context_docs = compression_retriever.invoke(user_query)
        return main_context_docs

    @observe(name='ref_context')
    def get_ref_context(self, main_context_docs):
        """
        Look up the sections referenced by the main context in MongoDB,
        using a single $in query per document for efficiency.
        """
        all_reference_docs = []
        for context in main_context_docs:
            references_list = context.metadata.get('references', [])
            if not isinstance(references_list, list) or not references_list:
                continue  # skip to the next context if it has no references

            # Strip the "มาตรา" (section) prefix to get bare section numbers
            ref_numbers = [
                ref_str.replace("มาตรา", "").strip()
                for ref_str in references_list
            ]

            # One $in query covering all referenced section numbers
            mongo_query = {
                "law_type": "summary",
                "section_number": {"$in": ref_numbers}
            }
            projection = {
                "_id": 1,
                "text": 1,
                "document_type": 1,
                "law_type": 1,
                "law_name": 1,
                "chapter": 1,
                # "publication_date": 1,
                # "effective_date": 1,
                # "publication_date_utc": 1,
                # "effective_date_utc": 1,
                # "royal_gazette_volume": 1,
                # "royal_gazette_no": 1,
                # "royal_gazette_page": 1,
                "chunk_type": 1,
                "section_number": 1
            }
            results = self.collection.find(mongo_query, projection)
            all_reference_docs.extend(list(results))

        # Deduplicate: documents sharing an _id overwrite each other
        ref_docs_by_id = {}
        for doc in all_reference_docs:
            ref_docs_by_id[doc["_id"]] = doc
        return list(ref_docs_by_id.values())

    # Format the main context for the prompt: needs law_name, section_number (มาตรา),
    # publication_date (if present), effective_date (if present)
    def format_main_context(self, list_of_documents):
        """
        input: list of LangChain Documents
        output: text to forward to the prompt
        """
        formatted_docs = []
        for i, doc in enumerate(list_of_documents):
            law_name = doc.metadata.get('law_name', '-')
            chapter = doc.metadata.get('chapter', '-')
            section_number = doc.metadata.get('section_number', '-')
            publication_date = doc.metadata.get('publication_date', '-')  # not present on every document
            effective_date = doc.metadata.get('effective_date', '-')  # not present on every document
            content = doc.page_content
            formatted = "\n".join([
                f"Doc{i}",
                f"{law_name}",
                f"{chapter}",  # a missing comma here used to glue the chapter and section lines together
                f"มาตรา\t{section_number}",
                content,
                f"ประกาศ\t{publication_date}",
                f"เริ่มใช้\t{effective_date}"
            ])
            formatted_docs.append(formatted)
        return "\n\n".join(formatted_docs)

    def format_ref_context(self, list_of_docs):
        formatted_ref_docs = []
        for i, doc in enumerate(list_of_docs):
            law_name = doc.get('law_name', '-')
            chapter = doc.get('chapter', '-')
            section_number = doc.get('section_number', '-')
            content = doc.get('text', '-')
            formatted = "\n".join([
                f"{law_name}",
                f"{chapter}",  # same missing comma as in format_main_context, fixed
                f"มาตรา\t{section_number}",
                content,
            ])
            formatted_ref_docs.append(formatted)
        return "\n\n".join(formatted_ref_docs)

    # ----- Chat! -----
    # History
    def append_history(self, message: Union[HumanMessage, AIMessage]):
        self.history.append(message)

    def get_formatted_history_for_llm(self, n_turns: int = 3) -> list:
        """Returns the last n_turns of history as a list of Message objects."""
        return self.history[-(n_turns * 2):]

    # Classify
    @observe(name='classify_input_type')
    def classify_input(self, user_input: str) -> str:
        history_content_list = [msg.content for msg in self.history]  # keep only the message text
        return classify_input_type(user_input, history=history_content_list)

    # Chat
    @observe(name="chat_flow")
    async def chat(self, user_input: str) -> str:
        history_before_current_input = self.history[:]  # kept for the optional call_rag history argument below
        self.append_history(HumanMessage(content=user_input))
        try:
            input_type = self.classify_input(user_input)
        except Exception as e:
            print(f"Error classifying input type: {e}. Defaulting to Non-RAG.")
            input_type = "Non-RAG"
        ai_response_content = ""
        if input_type == "RAG":
            # print("[RAG FLOW]")
            ai_response_content = await self.call_rag(user_input)  # , history_before_current_input)
        else:
            # print(f"[{input_type} FLOW (Treated as NON-RAG)]")
            ai_response_content = await self.call_non_rag(user_input)
        self.append_history(AIMessage(content=ai_response_content))
        # print(f"AI:::: {ai_response_content}")
        # print(input_type)
        return ai_response_content

    @observe(name='rag_flow')
    async def call_rag(self, user_input: str) -> str:
        # main context
        context_docs = self.get_main_context(user_input, law_type="summary", chunk_type="section")
        # print(context_docs)
        main_context_str = self.format_main_context(context_docs)
        # print(main_context_str)

        # ref context
        ref_context_docs = self.get_ref_context(context_docs)
        try:
            ref_context_str = self.format_ref_context(ref_context_docs)
        except Exception:
            ref_context_str = "-"

        history_for_llm_prompt = self.get_formatted_history_for_llm(n_turns=3)

        rag_input_data = {
            "question": user_input,
            "main_context": main_context_str,
            "ref_context": ref_context_str,
            "history": history_for_llm_prompt
        }
        try:
            prompt_messages = RAG_CHAT_PROMPT.format_messages(**rag_input_data)
            response = await self.llm_main.ainvoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            response_string = response.content
            # Strip HTML-like tags and markdown heading markers from the model output
            clean_response = re.sub(r"<[^>]+>", "", response_string)
            clean_response = re.sub(r"#+", "", clean_response)
            clean_response = clean_response.strip()
            return clean_response
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    @observe(name='non_rag_flow')
    async def call_non_rag(self, user_input: str) -> str:
        prompt_messages = NON_RAG_PROMPT.format(user_input=user_input)
        response = await self.llm_main.ainvoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
        # Guard against a None/empty content in the response
        if not response or not response.content:
            return "ขออภัย ระบบไม่สามารถตอบคำถามได้ในขณะนี้"  # "Sorry, the system cannot answer right now."
        return response.content.strip()
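

# ----- Usage sketch -----
# A minimal, illustrative entry point (not part of the original module), assuming
# the MONGO_* and model API environment variables above are set and the utils.*
# dependencies are importable. The question below is a placeholder asking about
# annual sick-leave entitlement.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        bot = ChatLaborLaw(model_name_llm="jai-chat-1-3-2", temperature=0)
        answer = await bot.chat("ลูกจ้างมีสิทธิลาป่วยปีละกี่วัน")  # placeholder question
        print(answer)

    asyncio.run(_demo())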