File size: 3,976 Bytes
7d134e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Union

from ..errors import HfHubHTTPError, RepositoryNotFoundError, TemplateError
from ..utils import is_minijinja_available


def _import_minijinja():
    """Import the optional `minijinja` dependency and return the module object.

    Returns:
        The imported `minijinja` module.

    Raises:
        `ImportError`: If `minijinja` is not installed in the current environment.
    """
    if is_minijinja_available():
        import minijinja  # noqa: F401

        return minijinja
    # Dependency is missing: tell the user how to install it.
    raise ImportError("Cannot render template. Please install minijinja using `pip install minijinja`.")


def render_chat_prompt(
    *,
    model_id: str,
    messages: List[Dict[str, str]],
    token: Union[str, bool, None] = None,
    add_generation_prompt: bool = True,
    **kwargs,
) -> str:
    """Render a chat prompt using a model's chat template.

    The chat template is fetched from the model's `tokenizer_config` and compiled with `minijinja`
    (see `_fetch_and_compile_template`), then rendered with the given messages and extra variables.

    Args:
        model_id (`str`):
            The model id.
        messages (`List[Dict[str, str]]`):
            The list of messages to render.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
        add_generation_prompt (`bool`, *optional*, defaults to `True`):
            Passed to the template as the `add_generation_prompt` variable.
        **kwargs:
            Additional variables made available to the template.

    Returns:
        `str`: The rendered chat prompt.

    Raises:
        `ImportError`: If `minijinja` is not installed.
        `TemplateError`: If there's any issue while fetching, compiling or rendering the chat template.
    """
    # Resolve the optional dependency first: its module is needed below to catch render errors.
    minijinja = _import_minijinja()
    template = _fetch_and_compile_template(model_id=model_id, token=token)

    try:
        return template(messages=messages, add_generation_prompt=add_generation_prompt, **kwargs)
    except minijinja.TemplateError as e:
        # Re-raise as the library's own error type so callers only need to catch one exception class.
        raise TemplateError(f"Error while trying to render chat prompt for model '{model_id}': {e}") from e


@lru_cache  # TODO: lru_cache for raised exceptions
def _fetch_and_compile_template(*, model_id: str, token: Union[str, bool, None]) -> Callable:
    """Fetch and compile a model's chat template.

    Method is cached to avoid fetching the same model's config multiple times.

    Args:
        model_id (`str`):
            The model id.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.

    Returns:
        `Callable`: A callable that takes the template variables as keyword arguments (e.g. `messages`,
        `add_generation_prompt`) and returns the rendered chat prompt. String special tokens found in the
        tokenizer config (e.g. `bos_token`) are provided as defaults; a keyword argument with the same
        name overrides them.

    Raises:
        `TemplateError`: If the model config cannot be fetched, contains no valid chat template, or the
        template fails to compile.
    """
    from huggingface_hub.hf_api import HfApi

    minijinja = _import_minijinja()

    # 1. fetch config from API
    try:
        config = HfApi(token=token).model_info(model_id).config
    except RepositoryNotFoundError as e:
        raise TemplateError(f"Cannot render chat template: model '{model_id}' not found.") from e
    except HfHubHTTPError as e:
        raise TemplateError(f"Error while trying to fetch chat template for model '{model_id}': {e}") from e

    # 2. check config validity
    if config is None:
        raise TemplateError(f"Config not found for model '{model_id}'.")
    tokenizer_config = config.get("tokenizer_config")
    if tokenizer_config is None:
        raise TemplateError(f"Tokenizer config not found for model '{model_id}'.")
    if tokenizer_config.get("chat_template") is None:
        raise TemplateError(f"Chat template not found in tokenizer_config for model '{model_id}'.")
    chat_template = tokenizer_config["chat_template"]
    if not isinstance(chat_template, str):
        raise TemplateError(f"Chat template must be a string, not '{type(chat_template)}' (model: {model_id}).")

    # Collect token-like entries so templates can reference e.g. `bos_token` / `eos_token`.
    # Values are either plain strings or serialized `AddedToken` dicts whose text is under "content".
    # NOTE(review): the substring match also picks up non-token string keys such as `tokenizer_class`;
    # harmless as extra template variables, but worth tightening if it ever causes collisions.
    special_tokens: Dict[str, Optional[str]] = {}
    for key, value in tokenizer_config.items():
        if "token" in key:
            if isinstance(value, str):
                special_tokens[key] = value
            elif isinstance(value, dict) and value.get("__type") == "AddedToken":
                special_tokens[key] = value.get("content")

    # 3. compile template and return
    env = minijinja.Environment()
    try:
        env.add_template("chat_template", chat_template)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to compile chat template for model '{model_id}': {e}") from e

    def _render(**kwargs) -> str:
        # Merge into a single mapping so caller-supplied variables override special tokens with the same
        # name. Unpacking both mappings separately into the call would raise
        # "TypeError: got multiple values for keyword argument" on any overlap.
        return env.render_template("chat_template", **{**special_tokens, **kwargs})

    return _render