davidberenstein1957's picture
update deployment with API providers
f5ab4cb
raw
history blame
2.97 kB
import gradio as gr
from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from synthetic_dataset_generator.constants import (
API_KEYS,
HUGGINGFACE_BASE_URL,
MAGPIE_PRE_QUERY_TEMPLATE,
MODEL,
OLLAMA_BASE_URL,
OPENAI_BASE_URL,
TOKENIZER_ID,
)
# Position of the next key to hand out; advances on every call.
TOKEN_INDEX = 0


def _get_next_api_key():
    """Hand out API keys round-robin from ``API_KEYS``.

    Each call returns the next key in the list, wrapping around at the
    end, so consecutive requests are spread across all configured keys.
    """
    global TOKEN_INDEX
    key_position = TOKEN_INDEX % len(API_KEYS)
    TOKEN_INDEX += 1
    return API_KEYS[key_position]
def _get_llm(use_magpie_template=False, **kwargs):
    """Build an LLM client for the first configured backend.

    Backend priority: OpenAI-compatible endpoint, then Ollama, then a
    dedicated Hugging Face inference endpoint, then serverless Hugging
    Face inference. ``kwargs["generation_kwargs"]`` is translated to
    each backend's expected parameter names before construction.

    Args:
        use_magpie_template: accepted for interface compatibility; no
            branch currently forwards it (NOTE(review): only the
            serverless branch sets a Magpie template — confirm whether
            this flag was meant to be passed through).
        **kwargs: extra constructor arguments, typically including a
            ``generation_kwargs`` dict.

    Returns:
        A configured distilabel LLM instance.
    """
    if OPENAI_BASE_URL:
        # Bug fix: sanitize generation kwargs BEFORE constructing the
        # client. The original built OpenAILLM first and mutated kwargs
        # afterwards, so the renames below never took effect.
        # The OpenAI API calls stop sequences "stop" and has no
        # "do_sample" parameter.
        generation_kwargs = kwargs.get("generation_kwargs", {})
        if "stop_sequences" in generation_kwargs:
            generation_kwargs["stop"] = generation_kwargs.pop("stop_sequences")
        generation_kwargs.pop("do_sample", None)
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            generation_kwargs = kwargs.pop("generation_kwargs")
            # Ollama calls the max-token budget "num_predict".
            if "max_new_tokens" in generation_kwargs:
                generation_kwargs["num_predict"] = generation_kwargs.pop(
                    "max_new_tokens"
                )
            if "stop_sequences" in generation_kwargs:
                generation_kwargs["stop"] = generation_kwargs.pop(
                    "stop_sequences"
                )
            generation_kwargs.pop("do_sample", None)
            # Ollama expects sampling parameters nested under "options".
            kwargs["generation_kwargs"] = {"options": generation_kwargs}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Bug fix: the original indexed kwargs["generation_kwargs"]
        # directly and raised KeyError when no generation_kwargs were
        # passed; setdefault guards that case.
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    else:
        # Serverless HF inference: the only branch that configures a
        # Magpie pre-query template.
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
            **kwargs,
        )
    return llm
# Import-time smoke test: load the configured LLM and run one tiny
# generation so a misconfigured backend surfaces immediately.
llm = None
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    # Bug fix: when _get_llm() itself raised, `llm` was unbound and the
    # handler crashed with NameError, masking the real error.
    llm_name = llm.__class__.__name__ if llm is not None else "LLM"
    # NOTE(review): gr.Error is only displayed when *raised* inside a
    # Gradio event handler; constructing it here is a no-op. Kept
    # non-fatal to preserve the original import-time behavior.
    gr.Error(f"Error loading {llm_name}: {e}")