|
import os
import re

# Disable OpenTelemetry before importing any instrumented libraries.
os.environ["OTEL_SDK_DISABLED"] = "true"

import uuid
from typing import Union

from dotenv import load_dotenv

from utils.chat_prompts import RAG_CHAT_PROMPT, NON_RAG_PROMPT
from utils.reranker import RerankRetriever
from utils.input_classifier import classify_input_type

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI

from pymongo import MongoClient

from langfuse.langchain import CallbackHandler
from langfuse import observe

load_dotenv()

# Note: mongo_username / mongo_password are loaded but not referenced below;
# the credentials are assumed to be embedded in MONGO_CONNECTION_STRING.
mongo_username = os.environ.get('MONGO_USERNAME')
mongo_password = os.environ.get('MONGO_PASSWORD')
mongo_database = os.environ.get('MONGO_DATABASE')
mongo_connection_str = os.environ.get('MONGO_CONNECTION_STRING')
mongo_collection_name = os.environ.get('MONGO_COLLECTION')
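
# Expected .env keys, inferred from the lookups in this module. The LANGFUSE_*
# names are an assumption: the Langfuse CallbackHandler reads its credentials
# from the environment rather than from arguments in this file.
#
#   MONGO_USERNAME=...
#   MONGO_PASSWORD=...
#   MONGO_DATABASE=...
#   MONGO_CONNECTION_STRING=...
#   MONGO_COLLECTION=...
#   JAI_API_KEY=...
#   CHAT_BASE_URL=...
#   GOOGLE_API_KEY=...        # only needed for gemini-2.0-flash
#   LANGFUSE_PUBLIC_KEY=...
#   LANGFUSE_SECRET_KEY=...
#   LANGFUSE_HOST=...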
|
|
|
class ChatLaborLaw:
    def __init__(self, model_name_llm="jai-chat-1-3-2", temperature=0):
        self.session_id = str(uuid.uuid4())[:8]

        # Langfuse credentials are picked up from the environment by the handler.
        self.langfuse_handler = CallbackHandler()

        self.history = []
        self.model_name_llm = model_name_llm
        self.retriever = RerankRetriever()

        self.client = MongoClient(mongo_connection_str)
        self.db = self.client[mongo_database]
        self.collection = self.db[mongo_collection_name]

        if model_name_llm == "jai-chat-1-3-2":
            self.llm_main = ChatOpenAI(
                model=model_name_llm,
                api_key=os.getenv("JAI_API_KEY"),
                base_url=os.getenv("CHAT_BASE_URL"),
                temperature=temperature,
                max_tokens=2048,
                max_retries=2,
                seed=13,
            )
        elif model_name_llm == "gemini-2.0-flash":
            gemini_api_key = os.getenv("GOOGLE_API_KEY")
            if not gemini_api_key:
                raise ValueError("GOOGLE_API_KEY (for Gemini) not found in environment variables.")

            common_gemini_config = {
                "google_api_key": gemini_api_key,
                "temperature": temperature,
                "max_output_tokens": 2048,
                "convert_system_message_to_human": True,
            }
            self.llm_main = ChatGoogleGenerativeAI(
                model="gemini-2.0-flash",
                **common_gemini_config,
            )
        else:
            raise ValueError(f"Unsupported LLM model '{model_name_llm}'.")

    @observe(name='main_context')
    def get_main_context(self, user_query, **kwargs):
        compression_retriever = self.retriever.get_compression_retriever(**kwargs)
        main_context_docs = compression_retriever.invoke(user_query)
        return main_context_docs

    @observe(name='ref_context')
    def get_ref_context(self, main_context_docs):
        """
        Look up the context of the referenced sections (มาตรา) in MongoDB,
        using the $in operator so each source document needs only one query.
        """
|
        all_reference_docs = []

        for context in main_context_docs:
            references_list = context.metadata.get('references', [])

            if not isinstance(references_list, list) or not references_list:
                continue

            # "มาตรา 5" -> "5": keep only the bare section number.
            ref_numbers = [
                ref_str.replace("มาตรา", "").strip()
                for ref_str in references_list
            ]

            mongo_query = {
                "law_type": "summary",
                "section_number": {"$in": ref_numbers}
            }

            projection = {
                "_id": 1,
                "text": 1,
                "document_type": 1,
                "law_type": 1,
                "law_name": 1,
                "chunk_type": 1,
                "section_number": 1
            }

            results = self.collection.find(mongo_query, projection)
            all_reference_docs.extend(list(results))

        # Deduplicate by _id: several source documents may reference the same section.
        ref_docs_by_id = {}
        for doc in all_reference_docs:
            ref_docs_by_id[doc["_id"]] = doc

        return list(ref_docs_by_id.values())

    def format_main_context(self, list_of_documents):
        """
        Format a list of LangChain Documents into a single text block
        that is forwarded to the prompt.
        """
|
        formatted_docs = []

        for i, doc in enumerate(list_of_documents):
            law_name = doc.metadata.get('law_name', '-')
            section_number = doc.metadata.get('section_number', '-')
            publication_date = doc.metadata.get('publication_date', '-')
            effective_date = doc.metadata.get('effective_date', '-')
            content = doc.page_content

            formatted = "\n".join([
                f"Doc{i}",
                f"{law_name}",
                f"มาตรา\t{section_number}",
                content,
                f"ประกาศ\t{publication_date}",
                f"เริ่มใช้\t{effective_date}"
            ])

            formatted_docs.append(formatted)

        return "\n\n".join(formatted_docs)

    def format_ref_context(self, list_of_docs):
        """Format raw MongoDB reference documents into a text block for the prompt."""
        formatted_ref_docs = []

        for doc in list_of_docs:
            law_name = doc.get('law_name', '-')
            section_number = doc.get('section_number', '-')
            content = doc.get('text', '-')

            formatted = "\n".join([
                f"{law_name}",
                f"มาตรา\t{section_number}",
                content,
            ])
            formatted_ref_docs.append(formatted)

        return "\n\n".join(formatted_ref_docs)

    def append_history(self, message: Union[HumanMessage, AIMessage]):
        self.history.append(message)

    def get_formatted_history_for_llm(self, n_turns: int = 3) -> list:
        """Return the last n_turns of history as a list of Message objects."""
        # Each turn is a Human/AI message pair, hence n_turns * 2 messages.
        return self.history[-(n_turns * 2):]

    @observe(name='classify_input_type')
    def classify_input(self, user_input: str) -> str:
        # classify_input_type is expected to return the label "RAG" or "Non-RAG";
        # chat() branches on it.
        history_content_list = [msg.content for msg in self.history]
        return classify_input_type(user_input, history=history_content_list)
@observe(name="chat_flow") |
|
async def chat(self, user_input: str) -> str: |
|
history_before_current_input = self.history[:] |
|
self.append_history(HumanMessage(content=user_input)) |
|
|
|
try: |
|
input_type = self.classify_input(user_input) |
|
except Exception as e: |
|
print(f"Error classifying input type: {e}. Defaulting to Non-RAG.") |
|
input_type = "Non-RAG" |
|
|
|
ai_response_content = "" |
|
if input_type == "RAG": |
|
|
|
ai_response_content = await self.call_rag(user_input) |
|
else: |
|
|
|
ai_response_content = await self.call_non_rag(user_input) |
|
|
|
self.append_history(AIMessage(content=ai_response_content)) |
|
|
|
|
|
return ai_response_content |
|
|
|
|
|

    @observe(name='rag_flow')
    async def call_rag(self, user_input: str) -> str:
        context_docs = self.get_main_context(user_input, law_type="summary")
        main_context_str = self.format_main_context(context_docs)

        ref_context_docs = self.get_ref_context(context_docs)
        try:
            ref_context_str = self.format_ref_context(ref_context_docs)
        except Exception:
            ref_context_str = "-"

        history_for_llm_prompt = self.get_formatted_history_for_llm(n_turns=3)

        rag_input_data = {
            "question": user_input,
            "main_context": main_context_str,
            "ref_context": ref_context_str,
            "history": history_for_llm_prompt
        }

        try:
            prompt_messages = RAG_CHAT_PROMPT.format_messages(**rag_input_data)
            response = await self.llm_main.ainvoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            # Strip tag-like spans and markdown heading hashes from the model output.
            clean_response = re.sub(r"<[^>]+>", "", response.content)
            clean_response = re.sub(r"#+", "", clean_response)
            return clean_response.strip()
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    @observe(name='non_rag_flow')
    async def call_non_rag(self, user_input: str) -> str:
        prompt_messages = NON_RAG_PROMPT.format(user_input=user_input)
        response = await self.llm_main.ainvoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})

        if not response or not response.content:
            # Fallback: "Sorry, the system cannot answer your question right now."
            return "ขออภัย ระบบไม่สามารถตอบคำถามได้ในขณะนี้"

        return response.content.strip()
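

# Minimal usage sketch (illustrative, not part of the module's public surface).
# Assumes the env vars listed near the top are set and the utils.* modules are
# importable; the question below is a hypothetical example.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        bot = ChatLaborLaw(model_name_llm="jai-chat-1-3-2")
        # "How many days of annual leave is an employee entitled to?"
        answer = await bot.chat("ลูกจ้างมีสิทธิลาพักร้อนกี่วัน")
        print(answer)

    asyncio.run(_demo())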