import os
import random
import time

from dotenv import load_dotenv
from openai import OpenAI

from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT

# Load environment variables (JAI_API_KEY, CHAT_BASE_URL) from a local .env file.
load_dotenv()

# OpenAI-compatible client pointed at the JAI chat endpoint.
client_jai = OpenAI(
    api_key=os.environ.get("JAI_API_KEY"),
    base_url=os.environ.get("CHAT_BASE_URL"),
)

model = "jai-chat-1-3-2"


def generate_content_with_retry(client, model, prompt, max_retries=3):
    """
    Call client.chat.completions.create with retry logic.

    Retries on HTTP 503 errors using exponential backoff with jitter;
    any other error is re-raised immediately.
    """
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.choices[0].message.content
        except Exception as e:
            # Retry only on HTTP 503 (service unavailable). The OpenAI client
            # exposes the HTTP status as `status_code` on API status errors;
            # the `code` fallback preserves the original check.
            status = getattr(e, "status_code", None) or getattr(e, "code", None)
            if status == 503:
                print(f"Attempt {attempt + 1} failed with 503 error: {e}")
                wait_time = (2 ** attempt) + random.random()
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                time.sleep(wait_time)
            else:
                print(f"Attempt {attempt + 1} failed with a different error: {e}")
                raise

    print(f"Failed to generate content after {max_retries} retries.")
    return None


def classify_input_type(user_input: str, history: list = None) -> str:
    """
    Classifies the user input as 'RAG' or 'Non-RAG' using the LLM, considering chat history.
    Supports history as a list of strings or a list of dicts with 'type' and 'content'.
    """
    history_text = "None"

    if history:
        formatted_history = []

        # Only the last three messages are passed along as classification context.
        for i, msg in enumerate(history[-3:]):
            if isinstance(msg, dict):
                role = "Human" if msg.get("type") == "human" else "AI"
                content = msg.get("content", "")
            elif isinstance(msg, str):
                # For plain strings, assume the turns alternate and end with a
                # human message, so the role is inferred from position.
                role = "Human" if (len(history[-3:]) - i) % 2 == 1 else "AI"
                content = msg
            else:
                continue

            formatted_history.append(f"{role}: {content}")

        history_text = "\n".join(formatted_history)

    prompt_content = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text,
    )

    result = generate_content_with_retry(client_jai, model, prompt_content)

    if result is None:
        raise RuntimeError("Failed to classify input type after multiple retries.")

    return result
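

# Illustrative usage sketch, not part of the module's public surface: the sample
# question and history below are hypothetical, and assume the prompt in
# CLASSIFICATION_INPUT_PROMPT yields a plain 'RAG' / 'Non-RAG' label as
# described in classify_input_type's docstring.
if __name__ == "__main__":
    sample_history = [
        {"type": "human", "content": "What does the leave policy say about carry-over days?"},
        {"type": "ai", "content": "Up to five unused days can be carried into the next year."},
    ]
    label = classify_input_type("And how do I request that?", history=sample_history)
    print(f"Classification: {label}")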