import gradio as gr
from distilabel.llms import InferenceEndpointsLLM, OllamaLLM, OpenAILLM

from synthetic_dataset_generator.constants import (
    API_KEYS,
    HUGGINGFACE_BASE_URL,
    MAGPIE_PRE_QUERY_TEMPLATE,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
)

# Module-level cursor for round-robin rotation over the configured API keys.
TOKEN_INDEX = 0


def _get_next_api_key():
    """Return the next API key, cycling through API_KEYS round-robin."""
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key


def _get_llm(use_magpie_template=False, **kwargs):
    """Build an LLM client for whichever backend is configured.

    Backends are checked in order: an OpenAI-compatible server, an Ollama
    server, a dedicated Hugging Face Inference Endpoint, and finally the
    serverless Hugging Face Inference API. Incoming `generation_kwargs`
    use distilabel's naming and are translated into each client's dialect.
    """
    if OPENAI_BASE_URL:
        if "generation_kwargs" in kwargs:
            # The OpenAI API names the stop parameter "stop" and has no
            # "do_sample" switch, so normalize before building the client.
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"].pop(
                    "stop_sequences"
                )
            kwargs["generation_kwargs"].pop("do_sample", None)
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            generation_kwargs = kwargs["generation_kwargs"]
            # Ollama names the token limit "num_predict" and the stop list
            # "stop", and it has no "do_sample" switch either.
            if "max_new_tokens" in generation_kwargs:
                generation_kwargs["num_predict"] = generation_kwargs.pop(
                    "max_new_tokens"
                )
            if "stop_sequences" in generation_kwargs:
                generation_kwargs["stop"] = generation_kwargs.pop("stop_sequences")
            generation_kwargs.pop("do_sample", None)
            # OllamaLLM expects its sampling parameters nested under "options".
            kwargs["generation_kwargs"] = {"options": generation_kwargs}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Dedicated Inference Endpoint: force sampling so repeated calls do
        # not return identical generations.
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            **kwargs,
        )
    else:
        # Default: serverless Hugging Face Inference API, with Magpie
        # prompting available for instruction generation.
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
            **kwargs,
        )
    return llm


# Smoke-test the configured backend at import time so a misconfigured
# endpoint or credential fails fast rather than mid-pipeline.
llm = None
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    # gr.Error must be raised, not merely instantiated, to surface in Gradio.
    llm_name = llm.__class__.__name__ if llm is not None else "LLM"
    raise gr.Error(f"Error loading {llm_name}: {e}") from e
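
# Illustration of the key rotation, with hypothetical keys rather than values
# from this project: given API_KEYS = ["key-a", "key-b"], successive calls
# spread requests across tokens instead of exhausting a single rate limit.
#
#     _get_next_api_key()  # -> "key-a"
#     _get_next_api_key()  # -> "key-b"
#     _get_next_api_key()  # -> "key-a" again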
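
# Usage sketch: run this file directly to exercise the factory. The parameter
# values below are illustrative assumptions, not project defaults. Against an
# Ollama backend the kwargs are remapped to {"options": {"temperature": ...,
# "num_predict": ..., "stop": [...]}}; against an OpenAI-compatible server
# "stop_sequences" becomes "stop".
if __name__ == "__main__":
    demo_llm = _get_llm(
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 256,
            "stop_sequences": ["\nUser:"],
        }
    )
    demo_llm.load()
    print(demo_llm.generate([[{"content": "Say hello in one word.", "role": "user"}]]))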