# sdiazlor's picture
# refactor: add local save to README and improve layout
# 6a5179a
import os
import warnings
import argilla as rg
# --- Inference limits (integers parsed from the environment) ---------------
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", "2048"))
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", "1000"))
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", "5"))

# Directory to locally save the generated data; None when the variable is unset.
SAVE_LOCAL_DIR = os.getenv("SAVE_LOCAL_DIR")

# --- Generation model and provider endpoints -------------------------------
# NOTE(review): MODEL always has a non-empty default, so setting
# HUGGINGFACE_BASE_URL alone trips the MODEL/HUGGINGFACE_BASE_URL check in
# validate_configuration — confirm this is intended.
MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
TOKENIZER_ID = os.getenv("TOKENIZER_ID")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")

# --- Completion model -------------------------------------------------------
# Every *_COMPLETION setting falls back to its generation counterpart, so it
# only needs to be set when completions should use a different model/endpoint.
MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
    "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
)
VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)

# Provider URLs in a fixed order: OpenAI, Ollama, Hugging Face, vLLM.
# validate_configuration reads the Hugging Face entry by position (index 2),
# so this order must not change.
base_urls = [
    OPENAI_BASE_URL,
    OLLAMA_BASE_URL,
    HUGGINGFACE_BASE_URL,
    VLLM_BASE_URL,
]
base_urls_completion = [
    OPENAI_BASE_URL_COMPLETION,
    OLLAMA_BASE_URL_COMPLETION,
    HUGGINGFACE_BASE_URL_COMPLETION,
    VLLM_BASE_URL_COMPLETION,
]
def validate_configuration(base_urls, model, env_context=""):
    """Validate the configuration of the model and base URLs.

    Args:
        base_urls: Provider base URLs in the order
            ``[OpenAI, Ollama, Hugging Face, vLLM]``; unset entries are falsy
            (``None`` or ``""``). The Hugging Face URL is read by position.
        model: Model id used for inference. It must be empty when a Hugging
            Face base URL is given, and set for any other provider URL.
        env_context: Suffix appended to the environment variable names in
            error messages, e.g. ``"_COMPLETION"``.

    Raises:
        ValueError: If a Hugging Face base URL and a model id are both set, if
            a non-Hugging Face base URL is set without a model id, or if more
            than one base URL is set.
    """
    huggingface_url = base_urls[2]
    if huggingface_url and model:
        raise ValueError(
            f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
            "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
        )

    # The check above forbids a model id alongside a Hugging Face base URL, so
    # that URL must not also demand one here. Previously `any(base_urls)` also
    # counted the Hugging Face URL, so that configuration could never pass.
    if not model and not huggingface_url and any(base_urls):
        raise ValueError(
            f"`MODEL{env_context}` is not set. Please provide a model id for inference."
        )

    active_urls = [url for url in base_urls if url]
    if len(active_urls) > 1:
        raise ValueError(
            f"Multiple base URLs are provided: {', '.join(active_urls)}. "
            "Only one base URL can be set at a time."
        )
# Fail fast at import time on inconsistent settings: first for the generation
# model, then for the (optionally separate) completion model.
validate_configuration(base_urls, MODEL)
validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")

# Pick the first configured endpoint (OpenAI, Ollama, Hugging Face, vLLM order).
# When none is configured, the last entry (the vLLM value) is kept, exactly as
# a chained `a or b or c or d` expression would return it.
BASE_URL = next((url for url in base_urls if url), base_urls[-1])
BASE_URL_COMPLETION = next(
    (url for url in base_urls_completion if url), base_urls_completion[-1]
)
# API Keys
# HF_TOKEN is mandatory; import of this module fails without it.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable "
        "that has access to the Hugging Face Hub repositories and Inference Endpoints."
    )

_API_KEY = os.getenv("API_KEY")
if _API_KEY:
    # An explicit API_KEY replaces the Hugging Face token list entirely.
    API_KEYS = [_API_KEY]
else:
    # Otherwise collect HF_TOKEN followed by HF_TOKEN_1 ... HF_TOKEN_9.
    API_KEYS = [HF_TOKEN]
    API_KEYS.extend(os.getenv(f"HF_TOKEN_{n}") for n in range(1, 10))
# Drop unset/empty entries while keeping the original order.
API_KEYS = list(filter(None, API_KEYS))
# Determine if SFT is available and which Magpie pre-query template to use.
SFT_AVAILABLE = False
# Lower-cased spellings accepted as aliases for the canonical template names.
llama_options = ["llama3", "llama-3", "llama 3"]
qwen_options = ["qwen2", "qwen-2", "qwen 2"]

# 1) An explicit MAGPIE_PRE_QUERY_TEMPLATE wins: known aliases are normalised to
#    "llama3"/"qwen2"; any other value is used as given (lower-cased).
if passed_pre_query_template := os.getenv("MAGPIE_PRE_QUERY_TEMPLATE", "").lower():
    SFT_AVAILABLE = True
    if passed_pre_query_template in llama_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
    elif passed_pre_query_template in qwen_options:
        MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
    else:
        MAGPIE_PRE_QUERY_TEMPLATE = passed_pre_query_template
# 2) Otherwise infer the template from the model id. A substring match already
#    covers an exact match, so no separate equality test is needed.
elif any(option in MODEL.lower() for option in llama_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
elif any(option in MODEL.lower() for option in qwen_options):
    SFT_AVAILABLE = True
    MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"

# SFT generation is disabled whenever an OpenAI base URL is configured.
if OPENAI_BASE_URL:
    SFT_AVAILABLE = False

if not SFT_AVAILABLE:
    warnings.warn(
        "`SFT_AVAILABLE` is set to `False`. Use Hugging Face Inference Endpoints or Ollama to generate chat data, provide a `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE`. You can also use `HUGGINGFACE_BASE_URL` with vllm."
    )
    # Guarantees the name is always defined, even when no branch above set it.
    MAGPIE_PRE_QUERY_TEMPLATE = None
# Embeddings
STATIC_EMBEDDING_MODEL = "minishlab/potion-base-8M"

# Argilla: the plain ARGILLA_* variables take precedence; the *_SDG_REVIEWER
# variants are consulted only when the plain ones are unset or empty.
ARGILLA_API_URL = os.getenv("ARGILLA_API_URL") or os.getenv(
    "ARGILLA_API_URL_SDG_REVIEWER"
)
ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY") or os.getenv(
    "ARGILLA_API_KEY_SDG_REVIEWER"
)

# A client is built only when both the URL and the key are available;
# otherwise argilla_client stays None and a warning is emitted.
argilla_client = None
if ARGILLA_API_URL and ARGILLA_API_KEY:
    argilla_client = rg.Argilla(api_url=ARGILLA_API_URL, api_key=ARGILLA_API_KEY)
else:
    warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set or is empty")