# NOTE(review): the three lines originally here ("Spaces:", "Sleeping",
# "Sleeping") look like a Hugging Face Spaces status badge captured during
# extraction, not source code — confirm against the upstream repository.
import pytest | |
from llama_index.llms import ChatMessage, MessageRole | |
from private_gpt.components.llm.prompt_helper import ( | |
DefaultPromptStyle, | |
Llama2PromptStyle, | |
TagPromptStyle, | |
get_prompt_style, | |
) | |
@pytest.mark.parametrize(
    ("prompt_style", "expected_prompt_style"),
    [
        # NOTE(review): keys inferred from the imported style classes —
        # confirm against get_prompt_style's registry in prompt_helper.
        ("default", DefaultPromptStyle),
        ("llama2", Llama2PromptStyle),
        ("tag", TagPromptStyle),
    ],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
    """Each known prompt_style key resolves to its prompt-style class.

    The parametrize decorator was missing: without it, pytest reports
    `prompt_style` / `expected_prompt_style` as unknown fixtures and the
    test errors before running.
    """
    assert get_prompt_style(prompt_style) == expected_prompt_style
def test_get_prompt_style_failure():
    """An unrecognized style name raises ValueError with a precise message."""
    unknown_style = "unknown"
    with pytest.raises(ValueError) as raised:
        get_prompt_style(unknown_style)
    assert str(raised.value) == f"Unknown prompt_style='{unknown_style}'"
def test_tag_prompt_style_format():
    """TagPromptStyle renders each message as a '<|role|>: content' line
    and ends with an empty assistant tag awaiting completion."""
    style = TagPromptStyle()
    conversation = [
        ChatMessage(role=MessageRole.SYSTEM, content="You are an AI assistant."),
        ChatMessage(role=MessageRole.USER, content="Hello, how are you doing?"),
    ]
    wanted = "\n".join(
        [
            "<|system|>: You are an AI assistant.",
            "<|user|>: Hello, how are you doing?",
            "<|assistant|>: ",
        ]
    )
    assert style.messages_to_prompt(conversation) == wanted
def test_tag_prompt_style_format_with_system_prompt():
    """A configured default system prompt is injected when the conversation
    has none, and an explicit system message overrides it."""
    system_prompt = "This is a system prompt from configuration."
    style = TagPromptStyle(default_system_prompt=system_prompt)

    # Case 1: user-only conversation — the configured default is used.
    user_only = [
        ChatMessage(role=MessageRole.USER, content="Hello, how are you doing?"),
    ]
    wanted = "\n".join(
        [
            f"<|system|>: {system_prompt}",
            "<|user|>: Hello, how are you doing?",
            "<|assistant|>: ",
        ]
    )
    assert style.messages_to_prompt(user_only) == wanted

    # Case 2: explicit system message — it takes precedence over the default.
    with_system = [
        ChatMessage(
            role=MessageRole.SYSTEM, content="FOO BAR Custom sys prompt from messages."
        ),
        ChatMessage(role=MessageRole.USER, content="Hello, how are you doing?"),
    ]
    wanted = "\n".join(
        [
            "<|system|>: FOO BAR Custom sys prompt from messages.",
            "<|user|>: Hello, how are you doing?",
            "<|assistant|>: ",
        ]
    )
    assert style.messages_to_prompt(with_system) == wanted
def test_llama2_prompt_style_format():
    """Llama2PromptStyle wraps the system message in <<SYS>> delimiters and
    the whole turn in '<s> [INST] ... [/INST]'."""
    style = Llama2PromptStyle()
    conversation = [
        ChatMessage(role=MessageRole.SYSTEM, content="You are an AI assistant."),
        ChatMessage(role=MessageRole.USER, content="Hello, how are you doing?"),
    ]
    wanted = "\n".join(
        [
            "<s> [INST] <<SYS>>",
            " You are an AI assistant. ",
            "<</SYS>>",
            "",
            " Hello, how are you doing? [/INST]",
        ]
    )
    assert style.messages_to_prompt(conversation) == wanted
def test_llama2_prompt_style_with_system_prompt():
    """With Llama-2 formatting, a configured default system prompt fills the
    <<SYS>> section; an explicit system message replaces it."""
    system_prompt = "This is a system prompt from configuration."
    style = Llama2PromptStyle(default_system_prompt=system_prompt)

    # Case 1: user-only conversation — the configured default is used.
    user_only = [
        ChatMessage(role=MessageRole.USER, content="Hello, how are you doing?"),
    ]
    wanted = "\n".join(
        [
            "<s> [INST] <<SYS>>",
            f" {system_prompt} ",
            "<</SYS>>",
            "",
            " Hello, how are you doing? [/INST]",
        ]
    )
    assert style.messages_to_prompt(user_only) == wanted

    # Case 2: explicit system message — it takes precedence over the default.
    with_system = [
        ChatMessage(
            role=MessageRole.SYSTEM, content="FOO BAR Custom sys prompt from messages."
        ),
        ChatMessage(role=MessageRole.USER, content="Hello, how are you doing?"),
    ]
    wanted = "\n".join(
        [
            "<s> [INST] <<SYS>>",
            " FOO BAR Custom sys prompt from messages. ",
            "<</SYS>>",
            "",
            " Hello, how are you doing? [/INST]",
        ]
    )
    assert style.messages_to_prompt(with_system) == wanted