File size: 3,323 Bytes
8e5a9dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# from openai import OpenAI
from dotenv import load_dotenv
import os
from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT
# from google import genai
from openai import OpenAI
import random, time

# Load environment variables (API key, base URL) from a local .env file.
load_dotenv()

# OpenAI-compatible client pointed at the JAI chat endpoint.
# NOTE(review): assumes JAI_API_KEY and CHAT_BASE_URL are set in the
# environment -- .get() returns None otherwise, and the failure only
# surfaces later at call time. Verify deployment env provides both.
client_jai = OpenAI(
    api_key=os.environ.get("JAI_API_KEY"),
    base_url=os.environ.get("CHAT_BASE_URL")
)
# Model identifier used for every classification request.
model = "jai-chat-1-3-2"
# model = "openthaigpt72b"

# Alternative Gemini backend, kept for reference (disabled):
# gemi = os.environ["GEMINI_API_KEY"]
# client_jai = genai.Client(api_key=gemi)
# model = "gemini-2.0-flash"
# temperature = 0.0

def generate_content_with_retry(client, model, prompt, max_retries=3):
    """
    Call the chat-completions endpoint, retrying transient 503 errors.

    Sends *prompt* as a single user message to *model* via *client* and
    returns the response text.  On a 503 (service unavailable) the call is
    retried with exponential backoff plus jitter; any other exception is
    re-raised immediately.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model: Model identifier to request.
        prompt: User message content.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The content of the first completion choice, or ``None`` if every
        attempt failed with a 503.

    Raises:
        Exception: Re-raises whatever non-503 error the client raised.
    """
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.choices[0].message.content
        except Exception as e:
            # Depending on openai-python version / exception class, the HTTP
            # status is exposed as `code` or `status_code` -- check both.
            status = getattr(e, "code", None) or getattr(e, "status_code", None)
            if status != 503:
                print(f"Attempt {attempt + 1} failed with a different error: {e}")
                raise  # Not a transient 503 -- surface it to the caller.
            print(f"Attempt {attempt + 1} failed with 503 error: {e}")
            if attempt < max_retries - 1:  # don't sleep after the final attempt
                wait_time = (2 ** attempt) + random.random()  # exponential backoff + jitter
                print(f"Waiting {wait_time:.2f} seconds before retrying...")
                time.sleep(wait_time)

    print(f"Failed to generate content after {max_retries} retries.")
    return None  # All retries exhausted on 503s.


def classify_input_type(user_input: str, history: list = None) -> str:
    """
    Classify the user input as 'RAG' or 'Non-RAG' using the LLM.

    The last three entries of *history* (if any) are formatted into the
    classification prompt for context.  History entries may be dicts with
    'type'/'content' keys or plain strings; string histories are assumed to
    alternate Human/AI roles, with the most recent message attributed to
    the Human.

    Args:
        user_input: The raw user message to classify.
        history: Optional chat history -- a list of dicts or of strings.

    Returns:
        The raw classification text produced by the LLM.

    Raises:
        RuntimeError: If the LLM call returned nothing after all retries.
    """
    history_text = "None"

    if history:
        recent = history[-3:]   # slice once; only the last 3 turns are used
        n_recent = len(recent)  # hoisted out of the loop
        formatted_history = []

        for i, msg in enumerate(recent):
            # Case: history is list of dicts with explicit role markers.
            if isinstance(msg, dict):
                role = "Human" if msg.get("type") == "human" else "AI"
                content = msg.get("content", "")
            # Case: history is list of strings -- alternate roles, counted
            # back from the end so the final message is the Human's.
            elif isinstance(msg, str):
                role = "Human" if (n_recent - i) % 2 == 1 else "AI"
                content = msg
            else:
                continue  # skip invalid entry

            formatted_history.append(f"{role}: {content}")

        history_text = "\n".join(formatted_history)

    prompt_content = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text
    )
    # print(prompt_content)

    result = generate_content_with_retry(client_jai, model, prompt_content)

    if result is None:
        # RuntimeError is an Exception subclass, so existing
        # `except Exception` callers still catch it.
        raise RuntimeError("Failed to classify input type after multiple retries.")

    return result