# NOTE(review): removed extraction artifacts (file-size line, commit hash,
# line-number dump) that were not valid Python and broke parsing.
# from openai import OpenAI
from dotenv import load_dotenv
import os
from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT
# from google import genai
from openai import OpenAI
import random, time
# Load API credentials and endpoint URL from a local .env file.
load_dotenv()

# OpenAI-compatible client pointed at the JAI chat endpoint
# (base URL and key come from the environment).
client_jai = OpenAI(
    api_key=os.environ.get("JAI_API_KEY"),
    base_url=os.environ.get("CHAT_BASE_URL")
)
model = "jai-chat-1-3-2"
# Alternate backends kept for reference:
# model = "openthaigpt72b"
# gemi = os.environ["GEMINI_API_KEY"]
# client_jai = genai.Client(api_key=gemi)
# model = "gemini-2.0-flash"
# temperature = 0.0
def generate_content_with_retry(client, model, prompt, max_retries=3):
    """Call the chat-completions endpoint, retrying transient 503 errors.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model: Model identifier string to pass to the endpoint.
        prompt: Text sent as a single ``user`` message.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The assistant message content (str), or ``None`` if every attempt
        failed with a 503.

    Raises:
        Exception: Any non-503 error from the client is re-raised unchanged.
    """
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}]
                # temperature=temperature, # Optional: Restore if needed
            )
            return response.choices[0].message.content
        except Exception as e:
            # The OpenAI SDK exposes the HTTP status as `status_code`;
            # other/older clients may use `code`. The original check looked
            # only at `code`, so SDK 503s were never actually retried.
            status = getattr(e, "status_code", None)
            if status is None:
                status = getattr(e, "code", None)
            if status == 503:  # service unavailable: transient, worth retrying
                print(f"Attempt {attempt + 1} failed with 503 error: {e}")
                # Exponential backoff with jitter to avoid thundering herd.
                wait_time = (2 ** attempt) + random.random()
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                time.sleep(wait_time)
            else:
                print(f"Attempt {attempt + 1} failed with a different error: {e}")
                raise  # Re-raise the exception if it's not a 503
    print(f"Failed to generate content after {max_retries} retries.")
    return None  # Or raise an exception, depending on desired behavior
def classify_input_type(user_input: str, history: list = None) -> str:
    """Classify ``user_input`` as 'RAG' or 'Non-RAG' using the LLM.

    Up to the last three chat-history messages are folded into the prompt.
    History entries may be dicts carrying 'type'/'content' keys, or plain
    strings whose roles alternate (the most recent string is assumed Human).

    Raises:
        Exception: If the LLM produced no result after all retries.
    """
    history_text = "None"
    if history:
        recent = history[-3:]
        rendered = []
        for idx, entry in enumerate(recent):
            if isinstance(entry, dict):
                # Explicit role metadata on the message wins.
                speaker = "Human" if entry.get("type") == "human" else "AI"
                text = entry.get("content", "")
            elif isinstance(entry, str):
                # Bare strings alternate roles, counting back from Human.
                speaker = "Human" if (len(recent) - idx) % 2 == 1 else "AI"
                text = entry
            else:
                continue  # silently drop entries of unknown shape
            rendered.append(f"{speaker}: {text}")
        history_text = "\n".join(rendered)
    prompt_content = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text
    )
    # print(prompt_content)
    result = generate_content_with_retry(client_jai, model, prompt_content)
    if result is None:
        raise Exception("Failed to classify input type after multiple retries.")
    return result