import abc
import logging
from collections.abc import Sequence
from typing import Any, Literal

from llama_index.llms import ChatMessage, MessageRole
from llama_index.llms.llama_utils import (
    DEFAULT_SYSTEM_PROMPT,
    completion_to_prompt,
    messages_to_prompt,
)

logger = logging.getLogger(__name__)


class AbstractPromptStyle(abc.ABC):
    """Abstract class for prompt styles.

    This class is used to format a series of messages into a prompt that can
    be understood by the models. A series of messages represents the
    interaction(s) between a user and an assistant, and can be seen as a
    session between a user X and an assistant Y. This session holds, through
    the messages, the state of the conversation. For the model to understand
    it, the session needs to be formatted into a prompt (i.e. a string that
    the model can understand). Prompts can be formatted in different ways,
    depending on the model.

    The implementations of this class represent the different ways to format
    a series of messages into a prompt.
    """

    @abc.abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        logger.debug("Initializing prompt_style=%s", self.__class__.__name__)

    @abc.abstractmethod
    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        pass

    @abc.abstractmethod
    def _completion_to_prompt(self, completion: str) -> str:
        pass

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = self._messages_to_prompt(messages)
        logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
        return prompt

    def completion_to_prompt(self, completion: str) -> str:
        prompt = self._completion_to_prompt(completion)
        logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
        return prompt
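

# A minimal sketch of a concrete style (hypothetical, for illustration only):
# a subclass only needs to implement the two private hooks; the public
# messages_to_prompt()/completion_to_prompt() wrappers add the debug logging.
#
#     class UpperCasePromptStyle(AbstractPromptStyle):
#         def __init__(self) -> None:
#             super().__init__()
#
#         def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
#             return "\n".join(f"{m.role}: {m.content}" for m in messages).upper()
#
#         def _completion_to_prompt(self, completion: str) -> str:
#             return completion.upper()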


class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
    _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT

    def __init__(self, default_system_prompt: str | None) -> None:
        super().__init__()
        logger.debug("Got default_system_prompt='%s'", default_system_prompt)
        self.default_system_prompt = default_system_prompt


class DefaultPromptStyle(AbstractPromptStyle):
    """Default prompt style that uses the defaults from llama_utils.

    It basically passes None to the LLM, indicating it should use
    the default functions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        # Hacky way to override the functions: setting the public attributes
        # to None shadows the methods, so None gets passed on to the LLM.
        self.messages_to_prompt = None  # type: ignore[method-assign, assignment]
        self.completion_to_prompt = None  # type: ignore[method-assign, assignment]

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return ""

    def _completion_to_prompt(self, completion: str) -> str:
        return ""


class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Simple prompt style that just uses the default llama_utils functions.

    It transforms the sequence of messages into a prompt that should look like:

    ```text
    <s> [INST] <<SYS>> your system prompt here. <</SYS>>
    user message here [/INST] assistant (model) response here </s>
    ```
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        # If no system prompt is given, the default one of the implementation is used.
        super().__init__(default_system_prompt=default_system_prompt)

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return messages_to_prompt(messages, self.default_system_prompt)

    def _completion_to_prompt(self, completion: str) -> str:
        return completion_to_prompt(completion, self.default_system_prompt)
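

# Usage sketch (hypothetical, for illustration): a single-turn session. With
# default_system_prompt=None, llama_utils substitutes its own default system
# prompt; the exact output is whatever messages_to_prompt produces, roughly
# the shape documented in the class docstring above.
#
#     style = Llama2PromptStyle()
#     prompt = style.messages_to_prompt(
#         [ChatMessage(content="Hello!", role=MessageRole.USER)]
#     )
#     # e.g. "<s> [INST] <<SYS>> ...default system prompt... <</SYS>> Hello! [/INST]"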


class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
    """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.

    It transforms the sequence of messages into a prompt that should look like:

    ```text
    <|system|>: your system prompt here.
    <|user|>: user message here
    (possibly with context and question)
    <|assistant|>: assistant (model) response here.
    ```

    FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
    """

    def __init__(self, default_system_prompt: str | None = None) -> None:
        # We have to define a default system prompt here as the LLM will not
        # use the default llama_utils functions.
        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
        super().__init__(default_system_prompt)
        self.system_prompt: str = default_system_prompt

    def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        messages = list(messages)
        if messages[0].role != MessageRole.SYSTEM:
            logger.info(
                "Adding system_prompt='%s' to the given messages "
                "as there is none given in the session",
                self.system_prompt,
            )
            messages = [
                ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
                *messages,
            ]
        return self._format_messages_to_prompt(messages)

    def _completion_to_prompt(self, completion: str) -> str:
        return (
            f"<|system|>: {self.system_prompt.strip()}\n"
            f"<|user|>: {completion.strip()}\n"
            "<|assistant|>: "
        )

    @staticmethod
    def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
        """Format messages to a prompt with the `<|ROLE|>: MSG` style."""
        assert messages[0].role == MessageRole.SYSTEM
        prompt = ""
        for message in messages:
            role = message.role
            content = message.content or ""
            prompt += f"<|{role.lower()}|>: {content.strip()}\n"
        # We are missing the last <|assistant|> tag that will trigger a completion.
        prompt += "<|assistant|>: "
        return prompt
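

# Usage sketch (hypothetical, for illustration): a session without a system
# message gets the default system prompt prepended before formatting.
#
#     style = TagPromptStyle()
#     prompt = style.messages_to_prompt(
#         [ChatMessage(content="Hello!", role=MessageRole.USER)]
#     )
#     # prompt == "<|system|>: ...default system prompt...\n"
#     #           "<|user|>: Hello!\n"
#     #           "<|assistant|>: "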


def get_prompt_style(
    prompt_style: Literal["default", "llama2", "tag"] | None
) -> type[AbstractPromptStyle]:
    """Get the prompt style to use from the given string.

    :param prompt_style: The prompt style to use.
    :return: The prompt style class to use.
    """
    if prompt_style is None or prompt_style == "default":
        return DefaultPromptStyle
    elif prompt_style == "llama2":
        return Llama2PromptStyle
    elif prompt_style == "tag":
        return TagPromptStyle
    raise ValueError(f"Unknown prompt_style='{prompt_style}'")
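

# A minimal runnable sketch (assumption: this module is executed directly,
# outside the real application wiring). get_prompt_style returns the class,
# not an instance, so the caller instantiates it.
if __name__ == "__main__":
    prompt_style_cls = get_prompt_style("tag")
    style = prompt_style_cls(default_system_prompt="You are a helpful assistant.")
    print(style.completion_to_prompt("Hello!"))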