Spaces:
Sleeping
Sleeping
File size: 3,974 Bytes
cfd3735 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""Test Google PaLM Chat API wrapper."""
import pytest
from langchain.chat_models.google_palm import (
ChatGooglePalm,
ChatGooglePalmError,
_messages_to_prompt_dict,
)
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
)
def test_messages_to_prompt_dict_with_valid_messages() -> None:
    """A well-ordered conversation maps onto the PaLM prompt dict.

    The leading system message becomes ``context``, the human/AI example
    pairs become ``examples`` and the remaining turns become ``messages``.
    """
    pytest.importorskip("google.generativeai")
    conversation = [
        SystemMessage(content="Prompt"),
        HumanMessage(example=True, content="Human example #1"),
        AIMessage(example=True, content="AI example #1"),
        HumanMessage(example=True, content="Human example #2"),
        AIMessage(example=True, content="AI example #2"),
        HumanMessage(content="Real human message"),
        AIMessage(content="Real AI message"),
    ]
    example_turns = [
        {"author": "human", "content": "Human example #1"},
        {"author": "ai", "content": "AI example #1"},
        {"author": "human", "content": "Human example #2"},
        {"author": "ai", "content": "AI example #2"},
    ]
    real_turns = [
        {"author": "human", "content": "Real human message"},
        {"author": "ai", "content": "Real AI message"},
    ]

    prompt = _messages_to_prompt_dict(conversation)

    assert prompt == {
        "context": "Prompt",
        "examples": example_turns,
        "messages": real_turns,
    }
def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
    """A SystemMessage that does not come first is rejected."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as exc_info:
        _messages_to_prompt_dict(
            [
                HumanMessage(content="Real human message"),
                SystemMessage(content="Prompt"),
            ]
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "System message must be first" in str(exc_info.value)
def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
    """An AI example that is not preceded by a human example is rejected."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as exc_info:
        _messages_to_prompt_dict(
            [
                AIMessage(example=True, content="AI example #1"),
                HumanMessage(example=True, content="Human example #1"),
            ]
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "AI example message must be immediately preceded" in str(
        exc_info.value
    )
def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
    """A human example followed by a non-example AI message is rejected."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as exc_info:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=False, content="AI example #1"),
            ]
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "Human example message must be immediately followed" in str(
        exc_info.value
    )
def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
    """Example messages appearing after a real message are rejected."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as exc_info:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=False, content="Real message"),
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=True, content="AI example #1"),
            ]
        )
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "Message examples must come before other" in str(exc_info.value)
def test_chat_google_raises_with_invalid_temperature() -> None:
    """Constructing the chat model with temperature=2.0 raises ValueError."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as exc_info:
        ChatGooglePalm(google_api_key="fake", temperature=2.0)
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "must be in the range" in str(exc_info.value)
def test_chat_google_raises_with_invalid_top_p() -> None:
    """Constructing the chat model with top_p=2.0 raises ValueError."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as exc_info:
        ChatGooglePalm(google_api_key="fake", top_p=2.0)
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "must be in the range" in str(exc_info.value)
def test_chat_google_raises_with_invalid_top_k() -> None:
    """Constructing the chat model with a negative top_k raises ValueError."""
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as exc_info:
        ChatGooglePalm(google_api_key="fake", top_k=-5)
    # Inspect the raised exception itself; str() of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr of it.
    assert "must be positive" in str(exc_info.value)
|