import os
import uuid
from typing import Union

from dotenv import load_dotenv
from utils.chat_prompts import (
    NON_RAG_PROMPT,
    RAG_CHAT_PROMPT_ENG,
    RAG_CHAT_PROMPT_TH,
    RAG_CHAT_PROMPT_KOREAN,
    QUERY_REWRITING_PROMPT_OBJ,
)
from get_retriever_2 import final_retrievers
from input_classifier import classify_input_type, detect_language
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langfuse.callback import CallbackHandler
from langfuse.decorators import langfuse_context, observe

# Load environment variables from .env file
load_dotenv()


class Chat:
    def __init__(self, model_name_llm="jai-chat-1-3-2", temperature=0):
        self.session_id = str(uuid.uuid4())[:8]
        self.model_name_llm = model_name_llm
        self.langfuse_handler = CallbackHandler(
            secret_key=os.environ["LANGFUSE_SECRET_KEY"],
            public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
            host="https://us.cloud.langfuse.com",
            session_id=self.session_id,
        )

        # --- LLM initialization ---
        if model_name_llm == "jai-chat-1-3-2":
            self.llm_main = ChatOpenAI(
                model=model_name_llm,
                api_key=os.getenv("JAI_API_KEY"),
                base_url=os.getenv("CHAT_BASE_URL"),
                temperature=temperature,
                max_tokens=2048,
                max_retries=2,
                seed=13,
            )
            self.llm_rewriter = self.llm_main
        elif model_name_llm == "gemini-2.0-flash":
            gemini_api_key = os.getenv("GOOGLE_API_KEY")
            if not gemini_api_key:
                raise ValueError("GOOGLE_API_KEY (for Gemini) not found in environment variables.")
            common_gemini_config = {
                "google_api_key": gemini_api_key,
                "temperature": temperature,
                "max_output_tokens": 2048,
                "convert_system_message_to_human": True,
            }
            self.llm_main = ChatGoogleGenerativeAI(model="gemini-2.0-flash", **common_gemini_config)
            # The rewriter uses the same model and config as the main LLM.
            self.llm_rewriter = ChatGoogleGenerativeAI(model="gemini-2.0-flash", **common_gemini_config)
        else:
            raise ValueError(f"Unsupported LLM model '{model_name_llm}'.")

        self.history = []  # Stores LangChain message objects (HumanMessage / AIMessage)

    def append_history(self, message: Union[HumanMessage, AIMessage]):
        self.history.append(message)

    def get_formatted_history_for_llm(self, n_turns: int = 3) -> list:
        """Return the last n_turns of history as a list of message objects."""
        return self.history[-(n_turns * 2):]

    def get_stringified_history_for_rewrite(self, n_turns: int = 2) -> str:
        """
        Format the last n_turns of history (excluding the current, not-yet-appended
        user input) as a string for the query rewriter prompt.
        """
        history_to_format = self.history[-(n_turns * 2):]
        if not history_to_format:
            return "No history available."
        history_str_parts = []
        for msg in history_to_format:
            role = "User" if isinstance(msg, HumanMessage) else "AI"
            history_str_parts.append(f"{role}: {msg.content}")
        return "\n".join(history_str_parts)

    @observe(name="ClassifyInput")
    def classify_input(self, user_input: str) -> str:
        """Classify the input as RAG or Non-RAG, given the conversation so far."""
        history_content_list = [msg.content for msg in self.history]
        return classify_input_type(user_input, history=history_content_list)

    def format_docs(self, docs: list) -> str:
        return "\n\n".join(doc.page_content for doc in docs)

    @observe(name="GetRetrieverandPrompt")
    def get_retriever_and_prompt(self, lang_code: str):
        """
        Return the retriever and RAG prompt appropriate for the language,
        falling back to English (or the first configured retriever) when needed.
        """
        retriever = final_retrievers.get(lang_code)
        if lang_code == "Thai":
            prompt_template = RAG_CHAT_PROMPT_TH
        elif lang_code == "Korean":
            prompt_template = RAG_CHAT_PROMPT_KOREAN
        elif lang_code == "English":
            prompt_template = RAG_CHAT_PROMPT_ENG
        else:
            print(f"Warning: Unsupported language '{lang_code}' for RAG. Defaulting to English.")
            retriever = final_retrievers.get("English")
            prompt_template = RAG_CHAT_PROMPT_ENG

        if not retriever:
            available_langs = list(final_retrievers.keys())
            if available_langs:
                fallback_lang = available_langs[0]
                retriever = final_retrievers[fallback_lang]
                print(f"Warning: No retriever for '{lang_code}' or 'English'. Using first available: '{fallback_lang}'.")
                if fallback_lang == "Thai":
                    prompt_template = RAG_CHAT_PROMPT_TH
                elif fallback_lang == "Korean":
                    prompt_template = RAG_CHAT_PROMPT_KOREAN
                else:
                    prompt_template = RAG_CHAT_PROMPT_ENG
            else:
                raise ValueError("CRITICAL: No retrievers configured at all.")

        if not prompt_template:
            raise ValueError(f"CRITICAL: No RAG prompt template found for language '{lang_code}' or effective fallback.")
        return retriever, prompt_template

    @observe(name="Non-Rag-Flow")
    def call_non_rag(self, user_input: str, input_lang: str) -> str:
        """Answer general (non-retrieval) inputs directly with the main LLM."""
        try:
            if hasattr(NON_RAG_PROMPT, "format_messages"):
                # Chat prompt template: format into a list of messages.
                prompt_messages = NON_RAG_PROMPT.format_messages(user_input=user_input, input_lang=input_lang)
            elif isinstance(NON_RAG_PROMPT, str):
                # Plain string template: format it and wrap it in a single HumanMessage.
                formatted_prompt_str = NON_RAG_PROMPT.format(user_input=user_input, input_lang=input_lang)
                prompt_messages = [HumanMessage(content=formatted_prompt_str)]
            else:
                raise TypeError("NON_RAG_PROMPT is of an unsupported type.")
            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            return response.content.strip()
        except Exception as e:
            print(f"Error during Non-RAG LLM call: {e}")
            return "Sorry, I had trouble processing your general request."

    @observe(name="DetectLang")
    def _observe_detect_language(self, user_input: str) -> str:
        """Wrap the detect_language call so it appears as its own Langfuse span."""
        return detect_language(user_input)

    @observe(name="MainChatFlow")
    def chat(self, user_input: str) -> str:
        """Main entry point: detect language, classify the input, and route to RAG or Non-RAG."""
        try:
            input_lang_detected = self._observe_detect_language(user_input)
        except Exception as e:
            print(f"Error detecting language: {e}. Defaulting to Thai.")
            input_lang_detected = "Thai"

        # Snapshot the history before appending the current turn, so the query
        # rewriter only sees previous turns.
        history_before_current_input = self.history[:]
        self.append_history(HumanMessage(content=user_input))

        try:
            input_type = self.classify_input(user_input)
        except Exception as e:
            print(f"Error classifying input type: {e}. Defaulting to Non-RAG.")
            input_type = "Non-RAG"

        if input_type == "RAG":
            ai_response_content = self.call_rag_v2(user_input, input_lang_detected, history_before_current_input)
        else:
            ai_response_content = self.call_non_rag(user_input, input_lang_detected)
        self.append_history(AIMessage(content=ai_response_content))

        # Attach the full conversation so far to the current trace for easier review.
        try:
            full_conversation = "\n".join(
                f"{'User' if isinstance(m, HumanMessage) else 'AI'}: {m.content}"
                for m in self.history
            )
            langfuse_context.update_current_trace(metadata={"full_conversation": full_conversation})
        except Exception as e:
            print(f"Failed to update Langfuse metadata: {e}")
        return ai_response_content

    @observe(name="Rag-Flow")
    def call_rag_v2(self, user_input: str, input_lang: str, history_for_rewrite: list) -> str:
        """Retrieval-augmented flow: rewrite the query, retrieve context, then answer."""
        try:
            retriever, selected_rag_prompt = self.get_retriever_and_prompt(input_lang)
        except ValueError as e:
            print(f"Error in RAG setup: {e}")
            return f"Sorry, I encountered a configuration issue for {input_lang} RAG. Please contact support."

        # --- Query rewriting step (observed via its own decorator) ---
        query_for_retriever = self._rewrite_query_if_needed_v2(user_input, history_for_rewrite)

        try:
            context_docs = retriever.invoke(query_for_retriever)
        except Exception as e:
            print(f"Error during document retrieval: {e}")
            return "Sorry, I had trouble finding relevant information for your query."
        context_str = self.format_docs(context_docs)

        history_for_llm_prompt = self.get_formatted_history_for_llm(n_turns=3)
        rag_input_data = {
            "question": user_input,
            "context": context_str,
            "history": history_for_llm_prompt,
        }
        try:
            prompt_messages = selected_rag_prompt.format_messages(**rag_input_data)
            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            return response.content.strip()
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    @observe(name="RewriteQuery")
    def _rewrite_query_if_needed_v2(self, user_input: str, history_list: list) -> str:
        """Rewrite the query into a standalone form using recent history, with a sanity check."""
        if not history_list:
            return user_input
        history_str_parts = []
        for msg in history_list[-4:]:  # last 2 turns (user + AI message pairs)
            role = "User" if isinstance(msg, HumanMessage) else "AI"
            history_str_parts.append(f"{role}: {msg.content}")
        chat_history_str = "\n".join(history_str_parts) if history_str_parts else "No relevant history."
        try:
            rewrite_prompt_messages = QUERY_REWRITING_PROMPT_OBJ.format_messages(
                chat_history=chat_history_str,
                question=user_input
            )
            response = self.llm_rewriter.invoke(rewrite_prompt_messages, config={"callbacks": [self.langfuse_handler]})
            rewritten_query = response.content.strip()
            # Accept the rewrite only if it is non-empty and not wildly longer than the original.
            if rewritten_query and len(rewritten_query) < len(user_input) + 250:
                return rewritten_query
            else:
                print(f"Rewritten query validation failed. Using original: '{user_input}'")
                # A Langfuse score/event could be recorded here if desired.
                return user_input
        except Exception as e:
            print(f"Error during query rewriting: {e}. Using original query.")
            return user_input
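
# --- Usage sketch (illustrative, not part of the class) ---
# A minimal REPL-style driver showing how the class above is meant to be used:
# one Chat instance per session (so Langfuse groups all turns under one
# session_id), with chat() called once per user turn. This assumes the .env
# file provides the LANGFUSE_* keys plus JAI_API_KEY/CHAT_BASE_URL (or
# GOOGLE_API_KEY when model_name_llm="gemini-2.0-flash").
if __name__ == "__main__":
    chat_session = Chat(model_name_llm="jai-chat-1-3-2", temperature=0)
    print(f"Session {chat_session.session_id} started. Type 'quit' to exit.")
    while True:
        user_text = input("You: ").strip()
        if not user_text or user_text.lower() == "quit":
            break
        print(f"AI: {chat_session.chat(user_text)}")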