|
import pytest |
|
from llama_index.llms import ChatMessage, MessageRole |
|
|
|
from private_gpt.components.llm.prompt_helper import ( |
|
DefaultPromptStyle, |
|
Llama2PromptStyle, |
|
TagPromptStyle, |
|
get_prompt_style, |
|
) |
|
|
|
|
|
@pytest.mark.parametrize(
    ("prompt_style", "expected_prompt_style"),
    [
        ("default", DefaultPromptStyle),
        ("llama2", Llama2PromptStyle),
        ("tag", TagPromptStyle),
    ],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
    """Check that each known style name resolves to its expected prompt style.

    Runs once per ``(prompt_style, expected_prompt_style)`` pair listed in the
    parametrization above.
    """
    resolved = get_prompt_style(prompt_style)

    assert resolved == expected_prompt_style
|
|
|
|
|
def test_get_prompt_style_failure():
    """Check that an unrecognised style name raises ``ValueError``.

    The exception text must equal ``Unknown prompt_style='<name>'`` exactly.
    """
    prompt_style = "unknown"

    with pytest.raises(ValueError) as raised:
        get_prompt_style(prompt_style)

    # The message is checked once the ``with`` block has finished, using the
    # exception captured by ``pytest.raises``.
    message = str(raised.value)
    assert message == f"Unknown prompt_style='{prompt_style}'"
|
|
|
|
|
def test_tag_prompt_style_format():
    """Render a system + user exchange with a default ``TagPromptStyle``.

    The expected prompt puts each message on its own ``<|role|>: `` line and
    finishes with an open ``<|assistant|>: `` tag.
    """
    # Arrange
    style = TagPromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    # Act
    rendered = style.messages_to_prompt(conversation)

    # Assert
    assert rendered == (
        "<|system|>: You are an AI assistant.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )
|
|
|
|
|
def test_tag_prompt_style_format_with_system_prompt():
    """Exercise ``TagPromptStyle`` built with a ``default_system_prompt``.

    Two scenarios are asserted against the same style instance:

    1. Only a user message is supplied: the expected prompt opens with the
       configured default system prompt.
    2. The messages include their own system message: the expected prompt
       shows that message's text and not the configured default.
    """
    system_prompt = "This is a system prompt from configuration."
    style = TagPromptStyle(default_system_prompt=system_prompt)

    # Scenario 1: no system message among the inputs.
    user_only = [
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    prompt_with_default = style.messages_to_prompt(user_only)
    assert prompt_with_default == (
        f"<|system|>: {system_prompt}\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )

    # Scenario 2: an explicit system message is part of the inputs.
    system_and_user = [
        ChatMessage(
            content="FOO BAR Custom sys prompt from messages.",
            role=MessageRole.SYSTEM,
        ),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    prompt_with_custom = style.messages_to_prompt(system_and_user)
    assert prompt_with_custom == (
        "<|system|>: FOO BAR Custom sys prompt from messages.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )
|
|
|
|
|
def test_llama2_prompt_style_format():
    """Render a system + user exchange with a default ``Llama2PromptStyle``.

    The expected prompt wraps the system text in ``<<SYS>>`` / ``<</SYS>>``
    markers inside an ``[INST]`` section that is closed after the user text.
    """
    # Arrange
    style = Llama2PromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    # Act
    rendered = style.messages_to_prompt(conversation)

    # Assert
    assert rendered == (
        "<s> [INST] <<SYS>>\n"
        " You are an AI assistant. \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )
|
|
|
|
|
def test_llama2_prompt_style_with_system_prompt():
    """Exercise ``Llama2PromptStyle`` built with a ``default_system_prompt``.

    Two scenarios are asserted against the same style instance:

    1. Only a user message is supplied: the expected prompt carries the
       configured default system prompt between the ``<<SYS>>`` markers.
    2. The messages include their own system message: the expected prompt
       shows that message's text and not the configured default.
    """
    system_prompt = "This is a system prompt from configuration."
    style = Llama2PromptStyle(default_system_prompt=system_prompt)

    # Scenario 1: no system message among the inputs.
    user_only = [
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    prompt_with_default = style.messages_to_prompt(user_only)
    assert prompt_with_default == (
        "<s> [INST] <<SYS>>\n"
        f" {system_prompt} \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )

    # Scenario 2: an explicit system message is part of the inputs.
    system_and_user = [
        ChatMessage(
            content="FOO BAR Custom sys prompt from messages.",
            role=MessageRole.SYSTEM,
        ),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    prompt_with_custom = style.messages_to_prompt(system_and_user)
    assert prompt_with_custom == (
        "<s> [INST] <<SYS>>\n"
        " FOO BAR Custom sys prompt from messages. \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )
|
|