Spaces: Running on Zero
| from functools import lru_cache | |
| from typing import Callable, Dict, List, Optional, Union | |
| from ..errors import HfHubHTTPError, RepositoryNotFoundError, TemplateError | |
| from ..utils import is_minijinja_available | |
def _import_minijinja():
    """Return the `minijinja` module, raising a helpful `ImportError` when it is not installed."""
    if is_minijinja_available():
        import minijinja  # noqa: F401

        return minijinja
    raise ImportError("Cannot render template. Please install minijinja using `pip install minijinja`.")
def render_chat_prompt(
    *,
    model_id: str,
    messages: List[Dict[str, str]],
    token: Union[str, bool, None] = None,
    add_generation_prompt: bool = True,
    **kwargs,
) -> str:
    """Render a chat prompt using a model's chat template.

    Args:
        model_id (`str`):
            The model id.
        messages (`List[Dict[str, str]]`):
            The list of messages to render.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
        add_generation_prompt (`bool`, *optional*, defaults to `True`):
            Passed through to the chat template; templates typically use it to append the
            tokens that prompt an assistant reply.
        **kwargs:
            Additional variables forwarded as-is to the chat template.

    Returns:
        `str`: The rendered chat prompt.

    Raises:
        `TemplateError`: If there's any issue while fetching, compiling or rendering the chat template.
    """
    minijinja = _import_minijinja()
    template = _fetch_and_compile_template(model_id=model_id, token=token)

    try:
        return template(messages=messages, add_generation_prompt=add_generation_prompt, **kwargs)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to render chat prompt for model '{model_id}': {e}") from e
@lru_cache  # TODO: lru_cache for raised exceptions
def _fetch_and_compile_template(*, model_id: str, token: Union[str, None]) -> Callable:
    """Fetch and compile a model's chat template.

    Method is cached to avoid fetching the same model's config multiple times.

    Args:
        model_id (`str`):
            The model id.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.

    Returns:
        `Callable`: A callable that takes a list of messages and returns the rendered chat prompt.

    Raises:
        `TemplateError`: If the config cannot be fetched, is missing/invalid, or the template
            fails to compile.
    """
    # Local import to avoid a circular import with hf_api.
    from huggingface_hub.hf_api import HfApi

    minijinja = _import_minijinja()

    # 1. fetch config from API
    try:
        config = HfApi(token=token).model_info(model_id).config
    except RepositoryNotFoundError as e:
        raise TemplateError(f"Cannot render chat template: model '{model_id}' not found.") from e
    except HfHubHTTPError as e:
        raise TemplateError(f"Error while trying to fetch chat template for model '{model_id}': {e}") from e

    # 2. check config validity
    if config is None:
        raise TemplateError(f"Config not found for model '{model_id}'.")
    tokenizer_config = config.get("tokenizer_config")
    if tokenizer_config is None:
        raise TemplateError(f"Tokenizer config not found for model '{model_id}'.")
    chat_template = tokenizer_config.get("chat_template")
    if chat_template is None:
        raise TemplateError(f"Chat template not found in tokenizer_config for model '{model_id}'.")
    if not isinstance(chat_template, str):
        raise TemplateError(f"Chat template must be a string, not '{type(chat_template)}' (model: {model_id}).")

    # Collect special tokens (bos/eos/...) so templates can reference them by name.
    # "AddedToken" entries carry the literal token string under "content".
    special_tokens: Dict[str, Optional[str]] = {}
    for key, value in tokenizer_config.items():
        if "token" in key:
            if isinstance(value, str):
                special_tokens[key] = value
            elif isinstance(value, dict) and value.get("__type") == "AddedToken":
                special_tokens[key] = value.get("content")

    # 3. compile template and return
    env = minijinja.Environment()
    try:
        env.add_template("chat_template", chat_template)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to compile chat template for model '{model_id}': {e}") from e
    return lambda **kwargs: env.render_template("chat_template", **kwargs, **special_tokens)