File size: 2,780 Bytes
9b674e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

try:
    from .utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key
    from .custom_callback import customcallback
except ImportError:
    from utils.db import load_api_key, load_openai_url, load_model_settings, load_groq_api_key, load_google_api_key
    from custom_callback import customcallback



# Module-level callback instance, created once at import time and attached to
# every ChatOpenAI model built by get_model() (via its "callbacks" kwarg).
# NOTE(review): the exact behaviour of strip_tokens / answer_prefix_tokens is
# defined in custom_callback (not visible here) — confirm there.
the_callback = customcallback(strip_tokens=False, answer_prefix_tokens=["Answer"])



def get_model(high_context=False):
    """Instantiate the chat model selected in the stored settings.

    The model name, API keys and OpenAI base URL are read through the
    ``utils.db`` loaders on every call. The model name is then mapped to the
    LangChain chat-model class that serves it.

    Args:
        high_context: Only affects OpenAI models, and only when the stored
            OpenAI URL is ``"default"``. In that case the requested model is
            replaced with ``"gpt-4-turbo"``. It is ignored for every other
            provider and for custom OpenAI-compatible endpoints.

    Returns:
        A ``ChatOpenAI``, ``ChatOllama``, ``ChatGroq`` or
        ``ChatGoogleGenerativeAI`` instance.

    Raises:
        KeyError: If the stored model name is not one of the supported names
            listed in ``model_mapping`` below.
    """
    the_model = load_model_settings()
    the_api_key = load_api_key()
    the_groq_api_key = load_groq_api_key()
    the_google_api_key = load_google_api_key()
    the_openai_url = load_openai_url()

    def open_ai_base(high_context):
        """Return ChatOpenAI kwargs, pointing at a custom endpoint if configured."""
        if the_openai_url == "default":
            true_model = the_model
            if high_context:
                true_model = "gpt-4-turbo"
            return {"model": true_model, "api_key": the_api_key, "max_retries": 15,
                    "streaming": True, "callbacks": [the_callback]}
        # Custom OpenAI-compatible endpoint: keep the requested model as-is.
        return {"model": the_model, "api_key": the_api_key, "max_retries": 15,
                "streaming": True, "callbacks": [the_callback],
                "base_url": the_openai_url}

    # Constructor kwargs per provider class.
    args_mapping = {
        ChatOpenAI: open_ai_base(high_context=high_context),
        ChatOllama: {"model": the_model},
        # The Groq key must come from the Groq setting, not the OpenAI URL.
        # The "-groq" suffix only disambiguates the name locally, so it is
        # stripped before the name is sent to the API.
        ChatGroq: {"temperature": 0, "model_name": the_model.replace("-groq", ""),
                   "groq_api_key": the_groq_api_key},
        ChatGoogleGenerativeAI: {"model": the_model, "google_api_key": the_google_api_key},
    }

    # Supported model names -> LangChain class that serves them.
    model_mapping = {
        # OpenAI
        "gpt-4o": ChatOpenAI,
        "gpt-4-turbo": ChatOpenAI,
        "gpt-3.5": ChatOpenAI,
        "gpt-3.5-turbo": ChatOpenAI,

        # Ollama
        "llava": ChatOllama,
        "llama3": ChatOllama,
        "bakllava": ChatOllama,

        # Google Generative AI - Gemini
        "gemini-pro": ChatGoogleGenerativeAI,

        # Groq
        "mixtral-8x7b-groq": ChatGroq,
    }

    try:
        model_class = model_mapping[the_model]
    except KeyError:
        # Same exception type as before, but with a message naming the model.
        raise KeyError(f"Unsupported model: {the_model!r}") from None
    return model_class(**args_mapping[model_class])


def get_client():
    """Create an OpenAI SDK client from the stored API key and base URL.

    A stored URL equal to ``"default"`` means "use the SDK's built-in
    endpoint", so no ``base_url`` is passed in that case. Any other value is
    forwarded as ``base_url``.
    """
    client_kwargs = {"api_key": load_api_key()}
    configured_url = load_openai_url()
    if configured_url != "default":
        client_kwargs["base_url"] = configured_url
    return OpenAI(**client_kwargs)