# law_poc/utils/input_classifier.py
# SUMANA SUMANAKUL (ING) - commit 8e5a9dd - 3.32 kB
# from openai import OpenAI
from dotenv import load_dotenv
import os
from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT
# from google import genai
from openai import OpenAI
import random, time
# Load variables from a .env file (if present) into the process environment
# before the client below reads them.
load_dotenv()

# OpenAI SDK client used for the classification requests. The API key and the
# endpoint are taken from the JAI_API_KEY and CHAT_BASE_URL environment
# variables; os.environ.get yields None for either one that is unset.
client_jai = OpenAI(
    api_key=os.environ.get("JAI_API_KEY"),
    base_url=os.environ.get("CHAT_BASE_URL")
)

# Model identifier passed along with every request made through client_jai.
model = "jai-chat-1-3-2"

# Alternative configurations, currently disabled:
# model = "openthaigpt72b"
# gemi = os.environ["GEMINI_API_KEY"]
# client_jai = genai.Client(api_key=gemi)
# model = "gemini-2.0-flash"
# temperature = 0.0
def generate_content_with_retry(client, model, prompt, max_retries=3, base_delay=1.0):
    """
    Send ``prompt`` as a single user message and return the reply text,
    retrying with exponential backoff when the service answers HTTP 503.

    Args:
        client: Object exposing ``chat.completions.create(model=..., messages=...)``.
        model: Model identifier forwarded to the API call.
        prompt: Text sent as the content of one ``user`` message.
        max_retries: Total number of attempts to make (not extra retries).
        base_delay: Scale factor, in seconds, for the backoff. The wait before
            the next attempt is ``base_delay * (2 ** attempt + jitter)`` with
            ``jitter`` drawn from ``[0, 1)``; the default of 1.0 keeps the
            original ``2 ** attempt + jitter`` schedule.

    Returns:
        ``response.choices[0].message.content`` of the first successful
        attempt, or ``None`` when every attempt failed with a 503.

    Raises:
        Exception: Any error that is not a 503 is re-raised immediately.
    """
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.choices[0].message.content
        except Exception as e:
            # OpenAI SDK status errors expose the HTTP status as `status_code`;
            # `code` is still honoured for clients that report it there.
            statuses = (getattr(e, "status_code", None), getattr(e, "code", None))
            if 503 not in statuses:
                print(f"Attempt {attempt + 1} failed with a different error: {e}")
                raise  # Only 503 is treated as transient.
            print(f"Attempt {attempt + 1} failed with 503 error: {e}")
            # Sleeping after the final attempt would only delay the failure.
            if attempt + 1 < max_retries:
                wait_time = base_delay * ((2 ** attempt) + random.random())
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                time.sleep(wait_time)
    print(f"Failed to generate content after {max_retries} retries.")
    return None
def classify_input_type(user_input: str, history: list = None) -> str:
    """
    Ask the LLM to classify the user input, giving it the most recent chat
    history as context.

    Args:
        user_input: The message to classify.
        history: Optional chat history. Only the last three entries are used.
            Each entry may be a dict (role taken from its ``type`` key, text
            from ``content``) or a plain string; other entries are skipped.

    Returns:
        The model's reply, unmodified. The prompt is expected to elicit
        'RAG' or 'Non-RAG', but the value is not validated here.

    Raises:
        RuntimeError: If no reply was obtained after the retries.
    """
    recent = history[-3:] if history else []
    formatted_history = []
    for i, msg in enumerate(recent):
        if isinstance(msg, dict):
            role = "Human" if msg.get("type") == "human" else "AI"
            content = msg.get("content", "")
        elif isinstance(msg, str):
            # Plain strings carry no role: the newest entry is taken to be the
            # human's and roles alternate going backwards from there.
            role = "Human" if (len(recent) - i) % 2 == 1 else "AI"
            content = msg
        else:
            continue  # Skip entries that are neither dict nor str.
        formatted_history.append(f"{role}: {content}")

    # Use the "None" placeholder both for missing history and for history in
    # which every entry was skipped, rather than sending an empty string.
    history_text = "\n".join(formatted_history) if formatted_history else "None"

    prompt_content = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text,
    )
    result = generate_content_with_retry(client_jai, model, prompt_content)
    if result is None:
        raise RuntimeError("Failed to classify input type after multiple retries.")
    return result