import gradio as gr
from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM

from synthetic_dataset_generator.constants import (
    API_KEYS,
    HUGGINGFACE_BASE_URL,
    MAGPIE_PRE_QUERY_TEMPLATE,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
)

TOKEN_INDEX = 0


def _get_next_api_key():
    """Return the next key from API_KEYS, rotating round-robin across calls."""
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key
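
# Example (illustrative values, not shipped defaults): with
# API_KEYS = ["key_a", "key_b"], successive calls yield
# "key_a", "key_b", "key_a", ... spreading requests across keys.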


def _get_llm(use_magpie_template=False, **kwargs):
    """Instantiate an LLM for whichever backend is configured, translating
    ``generation_kwargs`` into the parameter names that backend expects."""
    if OPENAI_BASE_URL:
        # Sanitize kwargs *before* constructing the LLM: the OpenAI API calls
        # the stop parameter "stop" rather than "stop_sequences" and does not
        # accept "do_sample".
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs[
                    "generation_kwargs"
                ].pop("stop_sequences")
            kwargs["generation_kwargs"].pop("do_sample", None)
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            # Ollama calls the token limit "num_predict", the stop parameter
            # "stop", and expects all sampling options nested under "options".
            if "max_new_tokens" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["num_predict"] = kwargs[
                    "generation_kwargs"
                ].pop("max_new_tokens")
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs[
                    "generation_kwargs"
                ].pop("stop_sequences")
            kwargs["generation_kwargs"].pop("do_sample", None)
            kwargs["generation_kwargs"] = {"options": kwargs["generation_kwargs"]}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Dedicated Inference Endpoints need sampling enabled explicitly;
        # create generation_kwargs if the caller did not pass any.
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    else:
        # Default: serverless Hugging Face inference with Magpie support.
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
            **kwargs,
        )

    return llm
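
# Usage sketch (the generation_kwargs values below are illustrative, not
# defaults defined in this module): callers can pass backend-agnostic kwargs
# and let _get_llm translate them for whichever backend is configured, e.g.
#
#     llm = _get_llm(
#         use_magpie_template=True,
#         generation_kwargs={
#             "temperature": 0.9,
#             "max_new_tokens": 256,
#             "stop_sequences": ["<|im_end|>"],
#         },
#     )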


# Smoke-test the configured backend once at import time so that a
# misconfigured endpoint or key fails fast with a readable message.
llm = None
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    # gr.Error must be raised (not just instantiated) to surface the message;
    # guard against `llm` being unbound when _get_llm() itself fails.
    llm_name = llm.__class__.__name__ if llm is not None else "LLM"
    raise gr.Error(f"Error loading {llm_name}: {e}") from e