"""Prompt styles that turn chat messages or completions into model prompts."""

import abc
import logging
from collections.abc import Sequence
from typing import Any, Literal

from llama_index.llms import ChatMessage, MessageRole
from llama_index.llms.llama_utils import (
    DEFAULT_SYSTEM_PROMPT,
    completion_to_prompt,
    messages_to_prompt,
)

logger = logging.getLogger(__name__)


class AbstractPromptStyle(abc.ABC):
    """Abstract class for prompt styles.

    This class is used to format a series of messages into a prompt that can be
    understood by the models. A series of messages represents the interaction(s)
    between a user and an assistant. This series of messages can be considered as a
    session between a user X and an assistant Y. This session holds, through the
    messages, the state of the conversation. This session, to be understood by the
    model, needs to be formatted into a prompt (i.e. a string that the models
    can understand).

    Prompts can be formatted in different ways,
    depending on the model.

    The implementations of this class represent the different ways to format a
    series of messages into a prompt.
    """

    @abc.abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Log which concrete prompt style is being initialized.

        Declared abstract so that every concrete style defines its own
        constructor; subclasses still reach this body via ``super().__init__``.
        """
        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)

    @abc.abstractmethod
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format a sequence of chat messages into a prompt string."""

    @abc.abstractmethod
    def _completion_to_prompt(self, completion: str) -> str:
        """Format a single completion request into a prompt string."""

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format ``messages`` with the concrete style and log the result.

        :param messages: The chat messages of the session.
        :return: The prompt built by ``_messages_to_prompt``.
        """
        prompt = self._messages_to_prompt(messages)
        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
        return prompt

    def completion_to_prompt(self, completion: str) -> str:
        """Format ``completion`` with the concrete style and log the result.

        :param completion: The raw completion text.
        :return: The prompt built by ``_completion_to_prompt``.
        """
        prompt = self._completion_to_prompt(completion)
        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
        return prompt


class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
    """Base class for prompt styles that carry a default system prompt."""

    # Fallback system prompt, taken from llama_utils.
    _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT

    def __init__(self, default_system_prompt: str | None) -> None:
        """Store the default system prompt (which may be ``None``).

        :param default_system_prompt: The system prompt to remember, stored as-is.
        """
        super().__init__()
        logger.debug("Got default_system_prompt='%s'", default_system_prompt)
        self.default_system_prompt = default_system_prompt


class DefaultPromptStyle(AbstractPromptStyle):
    """Default prompt style that uses the defaults from llama_utils.

    It basically passes None to the LLM, indicating it should use
    the default functions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the style and replace the public formatters with ``None``."""
        super().__init__(*args, **kwargs)

        # Hacky way to override the functions
        # Override the functions to be None, and pass None to the LLM.
        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Return an empty string; only present to satisfy the abstract API."""
        return ""

    def _completion_to_prompt(self, completion: str) -> str:
        """Return an empty string; only present to satisfy the abstract API."""
        return ""


class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Simple prompt style that just uses the default llama_utils functions.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <s> [INST] <<SYS>> your system prompt here. <</SYS>>

    user message here [/INST] assistant (model) response here </s>
    ```
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        """Create the style with an optional system prompt.

        :param default_system_prompt: The system prompt forwarded to the
            llama_utils functions; ``None`` is forwarded unchanged.
        """
        # If no system prompt is given, the default one of the implementation is used.
        super().__init__(default_system_prompt=default_system_prompt)

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Delegate to llama_utils ``messages_to_prompt``."""
        return messages_to_prompt(messages, self.default_system_prompt)

    def _completion_to_prompt(self, completion: str) -> str:
        """Delegate to llama_utils ``completion_to_prompt``."""
        return completion_to_prompt(completion, self.default_system_prompt)


class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.

    It transforms the sequence of messages into a prompt that should look like:
    ```text
    <|system|>: your system prompt here.
    <|user|>: user message here
    (possibly with context and question)
    <|assistant|>: assistant (model) response here.
    ```

    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        """Create the style, falling back to the class default system prompt.

        :param default_system_prompt: The system prompt to use; a falsy value
            (``None`` or an empty string) selects ``_DEFAULT_SYSTEM_PROMPT``.
        """
        # We have to define a default system prompt here as the LLM will not
        # use the default llama_utils functions.
        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
        super().__init__(default_system_prompt)
        self.system_prompt: str = default_system_prompt

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Format ``messages`` in tag style, adding a system message if missing.

        :param messages: The chat messages of the session; may be empty.
        :return: The tag-formatted prompt, ending with an open assistant tag.
        """
        messages = list(messages)
        # An empty session has no system message either, so it takes the same
        # path instead of failing on ``messages[0]``.
        if not messages or messages[0].role != MessageRole.SYSTEM:
            logger.info(
                "Adding system_promt='%s' to the given messages as there are none "
                "given in the session",
                self.system_prompt,
            )
            messages = [
                ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
                *messages,
            ]
        return self._format_messages_to_prompt(messages)

    def _completion_to_prompt(self, completion: str) -> str:
        """Wrap ``completion`` as a user message after the system prompt.

        :param completion: The raw completion text; surrounding whitespace is
            stripped.
        :return: A three-line prompt ending with an open assistant tag.
        """
        return (
            f"<|system|>: {self.system_prompt.strip()}\n"
            f"<|user|>: {completion.strip()}\n"
            "<|assistant|>: "
        )

    @staticmethod
    def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
        """Format message to prompt with `<|ROLE|>: MSG` style.

        :param messages: Non-empty list whose first message has the system role.
        :return: One ``<|role|>: content`` line per message, followed by an
            open assistant tag.
        """
        assert messages[0].role == MessageRole.SYSTEM
        formatted_lines = "".join(
            f"<|{message.role.lower()}|>: {(message.content or '').strip()}\n"
            for message in messages
        )
        # we are missing the last <|assistant|> tag that will trigger a completion
        return formatted_lines + "<|assistant|>: "


def get_prompt_style(
    prompt_style: Literal["default", "llama2", "tag"] | None
) -> type[AbstractPromptStyle]:
    """Get the prompt style to use from the given string.

    :param prompt_style: The prompt style to use; ``None`` selects the default.
    :return: The prompt style class (not an instance) to use.
    :raises ValueError: If ``prompt_style`` is not a known style name.
    """
    if prompt_style is None or prompt_style == "default":
        return DefaultPromptStyle
    if prompt_style == "llama2":
        return Llama2PromptStyle
    if prompt_style == "tag":
        return TagPromptStyle
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")