Spaces:
Sleeping
Sleeping
File size: 4,479 Bytes
cfd3735 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from typing import List
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage
def create_messages() -> List[BaseMessagePromptTemplate]:
    """Build the four message prompt templates used by the chat prompt tests.

    Returns:
        A list holding, in order, a system, a human, an AI and a generic chat
        message template (the latter with role ``"test"``).
    """
    # Inner string prompts, one per message.
    system_prompt = PromptTemplate(
        template="Here's some context: {context}",
        input_variables=["context"],
    )
    human_prompt = PromptTemplate(
        template="Hello {foo}, I'm {bar}. Thanks for the {context}",
        input_variables=["foo", "bar", "context"],
    )
    ai_prompt = PromptTemplate(
        template="I'm an AI. I'm {foo}. I'm {bar}.",
        input_variables=["foo", "bar"],
    )
    generic_prompt = PromptTemplate(
        template="I'm a generic message. I'm {foo}. I'm {bar}.",
        input_variables=["foo", "bar"],
    )

    # Wrap each prompt in the message template matching its role.
    return [
        SystemMessagePromptTemplate(prompt=system_prompt),
        HumanMessagePromptTemplate(prompt=human_prompt),
        AIMessagePromptTemplate(prompt=ai_prompt),
        ChatMessagePromptTemplate(role="test", prompt=generic_prompt),
    ]
def create_chat_prompt_template() -> ChatPromptTemplate:
    """Build a chat prompt template from the messages of ``create_messages``.

    Returns:
        A ``ChatPromptTemplate`` declaring ``foo``, ``bar`` and ``context``
        as its input variables.
    """
    message_templates = create_messages()
    variable_names = ["foo", "bar", "context"]
    return ChatPromptTemplate(
        messages=message_templates,
        input_variables=variable_names,
    )
def test_create_chat_prompt_template_from_template() -> None:
    """Check that ``from_template`` yields a single human message template."""
    template = "hi {foo} {bar}"
    chat_prompt = ChatPromptTemplate.from_template(template)
    expected_messages = [HumanMessagePromptTemplate.from_template(template)]
    assert chat_prompt.messages == expected_messages
def test_create_chat_prompt_template_from_template_partial() -> None:
    """Check that ``from_template`` carries partial variables into the prompt."""
    template = "hi {foo} {bar}"
    chat_prompt = ChatPromptTemplate.from_template(
        template,
        partial_variables={"foo": "jim"},
    )
    # With ``foo`` pre-filled, only ``bar`` is expected to remain as an input.
    expected_inner_prompt = PromptTemplate(
        template=template,
        input_variables=["bar"],
        partial_variables={"foo": "jim"},
    )

    assert len(chat_prompt.messages) == 1
    only_message = chat_prompt.messages[0]
    assert isinstance(only_message, HumanMessagePromptTemplate)
    assert only_message.prompt == expected_inner_prompt
def test_chat_prompt_template() -> None:
    """Format the four-message template and check messages and string output."""
    values = {"foo": "foo", "bar": "bar", "context": "context"}
    chat_template = create_chat_prompt_template()

    prompt_value = chat_template.format_prompt(**values)
    assert isinstance(prompt_value, ChatPromptValue)

    # Each formatted message must carry its filled-in template text.
    formatted_messages = prompt_value.to_messages()
    expected_contents = [
        "Here's some context: context",
        "Hello foo, I'm bar. Thanks for the context",
        "I'm an AI. I'm foo. I'm bar.",
        "I'm a generic message. I'm foo. I'm bar.",
    ]
    assert len(formatted_messages) == 4
    for message, content in zip(formatted_messages, expected_contents):
        assert message.content == content

    # The string rendering prefixes every message with its role.
    expected = (
        "System: Here's some context: context\n"
        "Human: Hello foo, I'm bar. Thanks for the context\n"
        "AI: I'm an AI. I'm foo. I'm bar.\n"
        "test: I'm a generic message. I'm foo. I'm bar."
    )
    rendered_from_value = prompt_value.to_string()
    assert rendered_from_value == expected

    # ``format`` on the template must give the same string directly.
    rendered_from_template = chat_template.format(**values)
    assert rendered_from_template == expected
def test_chat_prompt_template_from_messages() -> None:
    """Check input variables and message count after ``from_messages``."""
    template = ChatPromptTemplate.from_messages(create_messages())
    expected_variables = sorted(["context", "foo", "bar"])
    # Compare sorted lists so the order of input variables does not matter.
    assert sorted(template.input_variables) == expected_variables
    assert len(template.messages) == 4
def test_chat_prompt_template_with_messages() -> None:
    """Check that a literal ``HumanMessage`` can follow message templates.

    The literal message must not add input variables and must come out
    unchanged as the last formatted message.
    """
    mixed_messages = [*create_messages(), HumanMessage(content="foo")]
    template = ChatPromptTemplate.from_messages(mixed_messages)

    expected_variables = sorted(["context", "foo", "bar"])
    assert sorted(template.input_variables) == expected_variables
    assert len(template.messages) == 5

    formatted = template.format_prompt(
        context="see",
        foo="this",
        bar="magic",
    )
    last_message = formatted.to_messages()[-1]
    assert last_message == HumanMessage(content="foo")
|