File size: 4,914 Bytes
cd47483
 
 
 
 
f5ab4cb
 
 
 
 
6a5179a
8291c8c
 
f5ab4cb
 
 
 
 
 
8dfc799
 
3b7b628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5ab4cb
 
 
3b7b628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5ab4cb
cd47483
ab34078
cd47483
 
 
 
1bff30e
f5ab4cb
 
 
 
 
cd47483
 
4106f96
 
62bb2f6
 
f5ab4cb
 
a0cefd0
f5ab4cb
a0cefd0
f5ab4cb
a0cefd0
 
3b90025
85b97c4
 
 
62bb2f6
 
85b97c4
 
 
62bb2f6
 
4106f96
f5ab4cb
cd47483
4106f96
 
cd47483
2841b26
cd47483
 
 
5ac0c97
 
 
cd47483
f5ab4cb
 
 
 
 
 
cd47483
ab34078
 
cd47483
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import warnings

import argilla as rg

# Inference
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))  # max tokens generated per request
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))  # cap on generated dataset rows
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))

# Directory to locally save the generated data (None disables local saving)
SAVE_LOCAL_DIR = os.getenv("SAVE_LOCAL_DIR")

# Models
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
TOKENIZER_ID = os.getenv("TOKENIZER_ID")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")

# Completion-specific overrides; each one falls back to its generation
# counterpart above, so setting none of these reuses the same model/URLs.
MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
    "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
)
VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)

# Ordered URL lists consumed by validate_configuration (index 2 = Hugging Face).
base_urls = [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
base_urls_completion = [
    OPENAI_BASE_URL_COMPLETION,
    OLLAMA_BASE_URL_COMPLETION,
    HUGGINGFACE_BASE_URL_COMPLETION,
    VLLM_BASE_URL_COMPLETION,
]


def validate_configuration(base_urls, model, env_context=""):
    """Validate that the model id and base-URL environment settings are coherent.

    Args:
        base_urls: URLs in the fixed order [OpenAI, Ollama, Hugging Face, vLLM];
            entries may be ``None``/empty when the corresponding env var is unset.
        model: model id configured for inference (may be falsy).
        env_context: suffix appended to env-var names inside error messages
            (e.g. ``"_COMPLETION"``).

    Raises:
        ValueError: when the Hugging Face base URL is combined with a model id,
            when any base URL is set without a model id, or when more than one
            base URL is set at the same time.
    """
    if base_urls[2] and model:
        raise ValueError(
            f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
            "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
        )

    configured_urls = [url for url in base_urls if url]

    if configured_urls and not model:
        raise ValueError(
            f"`MODEL{env_context}` is not set. Please provide a model id for inference."
        )

    if len(configured_urls) > 1:
        raise ValueError(
            f"Multiple base URLs are provided: {', '.join(configured_urls)}. "
            "Only one base URL can be set at a time."
        )
# Fail fast at import time if the env configuration is inconsistent,
# for both the generation model and the (optional) completion model.
validate_configuration(base_urls, MODEL)
validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")

# Pick the single configured base URL (validation above guarantees at most one
# is truthy); all-unset chains resolve to the last value, i.e. None.
BASE_URL = (
    OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
)
BASE_URL_COMPLETION = (
    OPENAI_BASE_URL_COMPLETION
    or OLLAMA_BASE_URL_COMPLETION
    or HUGGINGFACE_BASE_URL_COMPLETION
    or VLLM_BASE_URL_COMPLETION
)

# API Keys
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )

_API_KEY = os.getenv("API_KEY")
# An explicit API_KEY overrides the Hugging Face token pool; otherwise the
# pool is HF_TOKEN plus any of HF_TOKEN_1 .. HF_TOKEN_9 that are set.
if _API_KEY:
    API_KEYS = [_API_KEY]
else:
    API_KEYS = [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
# Drop unset (None/empty) entries so consumers only see usable tokens.
API_KEYS = [token for token in API_KEYS if token]

# Determine if SFT (chat-data generation) is available. Magpie-style
# generation needs a pre-query template, either passed explicitly via the
# MAGPIE_PRE_QUERY_TEMPLATE env var or inferred from a known model family.
SFT_AVAILABLE = False
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]

# Hoisted: MODEL.lower() was recomputed in every branch below.
_model_lower = MODEL.lower()

if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
    SFT_AVAILABLE = True
    # Normalize known aliases; anything else is passed through unchanged.
    if passed_pre_query_template in llama_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    elif passed_pre_query_template in qwen_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    else:
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
# Substring match covers exact equality too, so the previous redundant
# `MODEL.lower() in <options> or any(...)` check is collapsed to one `any`.
elif any(option in _model_lower for option in llama_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif any(option in _model_lower for option in qwen_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"

# OpenAI-compatible endpoints are not supported for Magpie-style generation.
if OPENAI_BASE_URL:
    SFT_AVAILABLE = False

if not SFT_AVAILABLE:
    warnings.warn(
        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vllm."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None

# Embeddings
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"

# Argilla
# Primary env vars take precedence; the *_SDG_REVIEWER variants are fallbacks.
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
    "ARGILLA_API_URL_SDG_REVIEWER"
)
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
    "ARGILLA_API_KEY_SDG_REVIEWER"
)

# Only build a client when both URL and key are present; otherwise warn and
# leave the client unset so callers can degrade gracefully.
argilla_client = None
if ARGILLA_API_URL and ARGILLA_API_KEY:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )
else:
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")