import os
import warnings
import argilla as rg
# Inference
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
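# The three limits above are read from environment variables, falling back to
# 2048 tokens, 1000 rows, and a batch size of 5 when unset.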
# Directory to locally save the generated data
SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
# Models
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
TOKENIZER_ID = os.getenv(key="TOKENIZER_ID", default=None)
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
# Just used in case of selecting a different model for completions
MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
    "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
)
VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)
base_urls = [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
base_urls_completion = [
    OPENAI_BASE_URL_COMPLETION,
    OLLAMA_BASE_URL_COMPLETION,
    HUGGINGFACE_BASE_URL_COMPLETION,
    VLLM_BASE_URL_COMPLETION,
]
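# Illustrative configuration (placeholder values, not defaults): exactly one backend
# base URL is expected alongside MODEL, e.g. for a local Ollama server one might set
#   MODEL="llama3.1"
#   OLLAMA_BASE_URL="http://127.0.0.1:11434/"
# in the environment before starting the app.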
# Validate the configuration of the model and base URLs.
def validate_configuration(base_urls, model, env_context=""):
    huggingface_url = base_urls[2]
    if huggingface_url and model:
        raise ValueError(
            f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
            "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
        )
    if not model and any(base_urls):
        raise ValueError(
            f"`MODEL{env_context}` is not set. Please provide a model id for inference."
        )
    active_urls = [url for url in base_urls if url]
    if len(active_urls) > 1:
        raise ValueError(
            f"Multiple base URLs are provided: {', '.join(active_urls)}. "
            "Only one base URL can be set at a time."
        )
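# For example, validate_configuration(["https://api.openai.com/v1", None, None, None], "gpt-4o")
# passes, while configuring two backends at once (say both an OpenAI and an Ollama base URL)
# raises a ValueError, as does setting any base URL without a model id.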
validate_configuration(base_urls, MODEL)
validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")
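# After validation at most one base URL is active, so each `or` chain below resolves to
# the single configured URL, or to None when only serverless inference via MODEL is used.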
BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
BASE_URL_COMPLETION = (
    OPENAI_BASE_URL_COMPLETION
    or OLLAMA_BASE_URL_COMPLETION
    or HUGGINGFACE_BASE_URL_COMPLETION
    or VLLM_BASE_URL_COMPLETION
)
# API Keys
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable to a token that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )
_API_KEY = os.getenv("API_KEY")
API_KEYS = (
    [_API_KEY]
    if _API_KEY
    else [HF_TOKEN] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
)
API_KEYS = [token for token in API_KEYS if token]
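# API_KEYS therefore contains the explicit API_KEY when provided; otherwise HF_TOKEN plus
# any optional HF_TOKEN_1 ... HF_TOKEN_9 values, with unset entries filtered out above.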
# Determine if SFT is available
SFT_AVAILABLE = False
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]
if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
    SFT_AVAILABLE = True
    if passed_pre_query_template in llama_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    elif passed_pre_query_template in qwen_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    else:
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
elif MODEL.lower() in llama_options or any(
    option in MODEL.lower() for option in llama_options
):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif MODEL.lower() in qwen_options or any(
    option in MODEL.lower() for option in qwen_options
):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
if OPENAI_BASE_URL:
    SFT_AVAILABLE = False
if not SFT_AVAILABLE:
    warnings.warn(
        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, or provide a `TOKENIZER_ID` and a `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vLLM."
    )
    MAGPIE_PRE_QUERY_TEMPLATE = None
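# Net effect of the block above: MAGPIE_PRE_QUERY_TEMPLATE is taken from the
# MAGPIE_PRE_QUERY_TEMPLATE env var or inferred from the model id (llama3 / qwen2),
# and SFT (chat data) generation is disabled when neither matches or when
# OPENAI_BASE_URL is configured.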
# Embeddings
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"
# Argilla
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
    "ARGILLA_API_URL_SDG_REVIEWER"
)
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
    "ARGILLA_API_KEY_SDG_REVIEWER"
)
if not ARGILLA_API_URL or not ARGILLA_API_KEY:
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")
    argilla_client = None
else:
    argilla_client = rg.Argilla(
        api_url=ARGILLA_API_URL,
        api_key=ARGILLA_API_KEY,
    )
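# Minimal usage sketch (illustrative; assumes this file is importable as a module
# named `constants`, which is a hypothetical name):
#
#   from constants import argilla_client, MODEL, SFT_AVAILABLE
#
#   if argilla_client is not None:
#       print(f"Argilla client ready; generating with {MODEL} (SFT available: {SFT_AVAILABLE})")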