# poc_hb / chat_3.py
# Author: Ing
# Commit 7b5c91e: "name observation process"
import os
import uuid

from dotenv import load_dotenv
from langchain_core.messages import AIMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langfuse.callback import CallbackHandler
from langfuse.decorators import langfuse_context, observe

from get_retriever_2 import final_retrievers
from input_classifier import classify_input_type, detect_language
from utils.chat_prompts import (
    NON_RAG_PROMPT,
    QUERY_REWRITING_PROMPT_OBJ,
    RAG_CHAT_PROMPT_ENG,
    RAG_CHAT_PROMPT_KOREAN,
    RAG_CHAT_PROMPT_TH,
)
# Load environment variables from a .env file into os.environ. Chat.__init__ later
# reads LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY, JAI_API_KEY, CHAT_BASE_URL and
# GOOGLE_API_KEY from the environment.
# NOTE(review): this runs after the imports above, so any imported module that reads
# environment variables at import time will not see values coming from .env.
load_dotenv()
class Chat:
    """Conversational front-end that routes each user turn to a RAG or a non-RAG LLM flow.

    An instance holds a short random session id (attached to Langfuse traces),
    a main LLM plus a query-rewriter LLM, and an in-memory list of LangChain
    ``HumanMessage`` / ``AIMessage`` objects that forms the conversation history.
    """

    # A rewritten retrieval query is accepted only if it is shorter than the
    # original question plus this many characters; anything longer is treated
    # as a runaway rewrite and the original question is used instead.
    MAX_REWRITE_EXTRA_CHARS = 250

    def __init__(self, model_name_llm="jai-chat-1-3-2", temperature=0):
        """Create a chat session backed by the selected LLM.

        Args:
            model_name_llm: ``"jai-chat-1-3-2"`` (OpenAI-compatible client using
                ``JAI_API_KEY`` / ``CHAT_BASE_URL``) or ``"gemini-2.0-flash"``
                (requires ``GOOGLE_API_KEY``).
            temperature: Sampling temperature passed to the LLM client(s).

        Raises:
            KeyError: If ``LANGFUSE_SECRET_KEY`` or ``LANGFUSE_PUBLIC_KEY`` is unset.
            ValueError: If the model name is unsupported, or Gemini is selected
                without ``GOOGLE_API_KEY``.
        """
        self.session_id = str(uuid.uuid4())[:8]
        self.model_name_llm = model_name_llm
        self.langfuse_handler = CallbackHandler(
            secret_key=os.environ['LANGFUSE_SECRET_KEY'],
            public_key=os.environ['LANGFUSE_PUBLIC_KEY'],
            host="https://us.cloud.langfuse.com",
            session_id=self.session_id,
        )

        # --- LLM Initialization ---
        if model_name_llm == "jai-chat-1-3-2":
            self.llm_main = ChatOpenAI(
                model=model_name_llm,
                api_key=os.getenv("JAI_API_KEY"),
                base_url=os.getenv("CHAT_BASE_URL"),
                temperature=temperature,
                max_tokens=2048,
                max_retries=2,
                seed=13,
            )
            # One client serves both answering and query rewriting.
            self.llm_rewriter = self.llm_main
        elif model_name_llm == "gemini-2.0-flash":
            gemini_api_key = os.getenv("GOOGLE_API_KEY")
            if not gemini_api_key:
                raise ValueError("GOOGLE_API_KEY (for Gemini) not found in environment variables.")
            common_gemini_config = {
                "google_api_key": gemini_api_key,
                "temperature": temperature,
                "max_output_tokens": 2048,
                "convert_system_message_to_human": True,
            }
            self.llm_main = ChatGoogleGenerativeAI(
                model="gemini-2.0-flash",
                **common_gemini_config
            )
            self.llm_rewriter = ChatGoogleGenerativeAI(
                model="gemini-2.0-flash",
                **common_gemini_config
            )
        else:
            raise ValueError(f"Unsupported LLM model '{model_name_llm}'.")

        self.history = []  # LangChain HumanMessage / AIMessage objects, oldest first

    # ------------------------------------------------------------------ history

    @staticmethod
    def _last_turns(messages: list, n_turns: int) -> list:
        """Return the trailing ``2 * n_turns`` messages (n user/AI pairs) of *messages*."""
        # Guard: ``messages[-0:]`` would return the WHOLE list, not an empty one.
        if n_turns <= 0:
            return []
        return messages[-(n_turns * 2):]

    @staticmethod
    def _stringify_messages(messages: list) -> str:
        """Render messages as newline-joined ``"User: ..."`` / ``"AI: ..."`` lines.

        Any message that is not a ``HumanMessage`` is labelled ``AI``.
        """
        return "\n".join(
            f"{'User' if isinstance(msg, HumanMessage) else 'AI'}: {msg.content}"
            for msg in messages
        )

    def append_history(self, message: "HumanMessage | AIMessage"):
        """Append one message to the end of the conversation history."""
        self.history.append(message)

    def get_formatted_history_for_llm(self, n_turns: int = 3) -> list:
        """Return the last ``n_turns`` turns of history as a list of Message objects.

        Returns an empty list when ``n_turns`` is zero or negative.
        """
        return self._last_turns(self.history, n_turns)

    def get_stringified_history_for_rewrite(self, n_turns: int = 2) -> str:
        """
        Format the last n_turns of history as a string for the query rewriter prompt.

        The result covers only messages already in ``self.history``; the current user
        input is excluded unless it has already been appended.
        Returns ``"No history available."`` when there is nothing to format
        (including when ``n_turns`` is zero or negative).
        """
        history_to_format = self._last_turns(self.history, n_turns)
        if not history_to_format:
            return "No history available."
        return self._stringify_messages(history_to_format)

    # ------------------------------------------------------------ classification

    @observe(name="ClassifyInput")
    def classify_input(self, user_input: str) -> str:
        """Classify *user_input* (e.g. as ``"RAG"``) via ``classify_input_type``.

        The text content of every message in ``self.history`` is passed along as
        context; when called from ``chat()`` that history already ends with the
        current user input.
        """
        history_content_list = [msg.content for msg in self.history]
        return classify_input_type(user_input, history=history_content_list)

    def format_docs(self, docs: list) -> str:
        """Join the ``page_content`` of each retrieved document, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    @observe(name="GetRetrieverandPrompt")
    def get_retriever_and_prompt(self, lang_code: str):
        """Return ``(retriever, rag_prompt)`` for *lang_code*, applying fallbacks.

        Supported languages are ``"Thai"``, ``"Korean"`` and ``"English"``; any
        other value falls back to English with a warning. If no retriever is
        configured for the chosen language, English is tried next, then the first
        key of ``final_retrievers``. The returned prompt always matches the
        language of the retriever actually returned (English for unknown keys).

        Raises:
            ValueError: If ``final_retrievers`` is empty, or no RAG prompt template
                is available for the effective language.
        """
        prompts_by_lang = {
            "Thai": RAG_CHAT_PROMPT_TH,
            "Korean": RAG_CHAT_PROMPT_KOREAN,
            "English": RAG_CHAT_PROMPT_ENG,
        }

        effective_lang = lang_code
        if effective_lang not in prompts_by_lang:
            print(f"Warning: Unsupported language '{lang_code}' for RAG. Defaulting to English.")
            effective_lang = "English"

        retriever = final_retrievers.get(effective_lang)

        # Second choice: English, matching the default used for unsupported languages.
        if not retriever and effective_lang != "English" and final_retrievers.get("English"):
            print(f"Warning: No retriever for '{lang_code}'. Falling back to 'English'.")
            effective_lang = "English"
            retriever = final_retrievers["English"]

        # Last resort: whatever retriever is configured first.
        if not retriever:
            available_langs = list(final_retrievers)
            if not available_langs:
                raise ValueError("CRITICAL: No retrievers configured at all.")
            effective_lang = available_langs[0]
            retriever = final_retrievers[effective_lang]
            print(f"Warning: No retriever for '{lang_code}' or 'English'. Using first available: '{effective_lang}'.")

        prompt_template = prompts_by_lang.get(effective_lang, RAG_CHAT_PROMPT_ENG)
        if not prompt_template:
            raise ValueError(f"CRITICAL: No RAG prompt template found for language '{lang_code}' or effective fallback.")
        return retriever, prompt_template

    # ------------------------------------------------------------------- flows

    @observe(name="Non-Rag-Flow")
    def call_non_rag(self, user_input: str, input_lang: str) -> str:
        """Answer *user_input* directly with the main LLM, without retrieval.

        ``NON_RAG_PROMPT`` is filled with only ``user_input`` and ``input_lang``
        (no conversation history). Any failure is printed and a fixed apology
        string is returned instead of raising.
        """
        try:
            if hasattr(NON_RAG_PROMPT, "format_messages"):
                # Chat prompt template: keep its system/human message structure
                # (``.format`` would flatten it into a single string).
                prompt_messages = NON_RAG_PROMPT.format_messages(user_input=user_input, input_lang=input_lang)
            elif isinstance(NON_RAG_PROMPT, str):
                formatted_prompt_str = NON_RAG_PROMPT.format(user_input=user_input, input_lang=input_lang)
                prompt_messages = [HumanMessage(content=formatted_prompt_str)]
            else:
                raise TypeError("NON_RAG_PROMPT is of an unsupported type.")
            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            return response.content.strip()
        except Exception as e:
            print(f"Error during Non-RAG LLM call: {e}")
            return "Sorry, I had trouble processing your general request."

    @observe(name="DetectLang")
    def _observe_detect_language(self, user_input: str) -> str:
        """Wrap ``detect_language`` so it appears as its own Langfuse observation."""
        return detect_language(user_input)

    @observe(name="MainChatFlow")
    def chat(self, user_input: str) -> str:
        """Process one user turn and return the assistant's reply.

        Steps: detect the input language (default ``"Thai"`` on error), record the
        user message, classify the input (default ``"Non-RAG"`` on error), run the
        RAG or non-RAG flow, record the AI reply, and attach the session id and
        the full conversation transcript to the current Langfuse trace (best effort).
        """
        try:
            input_lang_detected = self._observe_detect_language(user_input)
        except Exception as e:
            print(f"Error detecting language: {e}. Defaulting to Thai.")
            input_lang_detected = "Thai"

        # Snapshot taken before this input is recorded; the RAG flow uses it for
        # query rewriting so the rewriter sees only earlier turns.
        history_before_current_input = self.history[:]
        self.append_history(HumanMessage(content=user_input))

        try:
            input_type = self.classify_input(user_input)
        except Exception as e:
            print(f"Error classifying input type: {e}. Defaulting to Non-RAG.")
            input_type = "Non-RAG"

        if input_type == "RAG":
            ai_response_content = self.call_rag_v2(user_input, input_lang_detected, history_before_current_input)
        else:
            # Every non-"RAG" classification is handled by the non-RAG flow.
            ai_response_content = self.call_non_rag(user_input, input_lang_detected)
        self.append_history(AIMessage(content=ai_response_content))

        # Best effort: tracing problems must never break the chat reply.
        try:
            langfuse_context.update_current_trace(
                session_id=self.session_id,
                metadata={"full_conversation": self._stringify_messages(self.history)},
            )
        except Exception as e:
            print(f"Failed to update Langfuse metadata: {e}")
        return ai_response_content

    @observe(name="Rag-Flow")
    def call_rag_v2(self, user_input: str, input_lang: str, history_for_rewrite: list) -> str:
        """Answer *user_input* with retrieval-augmented generation.

        Args:
            user_input: The current user question.
            input_lang: Detected language; selects retriever and RAG prompt.
            history_for_rewrite: Earlier messages used to rewrite the question
                into a standalone retrieval query.

        Returns:
            The model's reply, or a fixed apology string if setup, retrieval, or
            generation fails.
        """
        try:
            retriever, selected_rag_prompt = self.get_retriever_and_prompt(input_lang)
        except ValueError as e:
            print(f"Error in RAG setup: {e}")
            return f"Sorry, I encountered a configuration issue for {input_lang} RAG. Please contact support."

        # --- Query Rewriting Step ---
        query_for_retriever = self._rewrite_query_if_needed_v2(user_input, history_for_rewrite)
        try:
            context_docs = retriever.invoke(query_for_retriever)
        except Exception as e:
            print(f"Error during document retrieval: {e}")
            return "Sorry, I had trouble finding relevant information for your query."

        context_str = self.format_docs(context_docs)
        # NOTE(review): chat() appends the current HumanMessage to self.history before
        # calling this method, so this window ends with the current question, which is
        # also passed as "question" below — confirm the RAG prompt template does not
        # present it twice (history_for_rewrite holds the prior turns only).
        history_for_llm_prompt = self.get_formatted_history_for_llm(n_turns=3)
        rag_input_data = {
            "question": user_input,
            "context": context_str,
            "history": history_for_llm_prompt,
        }
        try:
            prompt_messages = selected_rag_prompt.format_messages(**rag_input_data)
            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [self.langfuse_handler]})
            return response.content.strip()
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    @observe(name="RewriteQuery")
    def _rewrite_query_if_needed_v2(self, user_input: str, history_list: list) -> str:
        """Rewrite *user_input* into a standalone retrieval query using recent history.

        With no history the input is returned unchanged. Otherwise the last two
        turns of *history_list* are given to ``QUERY_REWRITING_PROMPT_OBJ``. The
        rewrite is used only if non-empty and shorter than the original plus
        ``MAX_REWRITE_EXTRA_CHARS``. On validation failure or any error the
        original input is returned.
        """
        if not history_list:
            return user_input

        # history_list is non-empty here, so the two-turn window is never empty.
        chat_history_str = self._stringify_messages(self._last_turns(history_list, 2))
        try:
            rewrite_prompt_messages = QUERY_REWRITING_PROMPT_OBJ.format_messages(
                chat_history=chat_history_str,
                question=user_input,
            )
            response = self.llm_rewriter.invoke(rewrite_prompt_messages, config={"callbacks": [self.langfuse_handler]})
            rewritten_query = response.content.strip()
            if rewritten_query and len(rewritten_query) < len(user_input) + self.MAX_REWRITE_EXTRA_CHARS:
                return rewritten_query
            print(f"Rewritten query validation failed. Using original: '{user_input}'")
            return user_input
        except Exception as e:
            print(f"Error during query rewriting: {e}. Using original query.")
            return user_input