File size: 3,974 Bytes
cfd3735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Test Google PaLM Chat API wrapper."""

import pytest

from langchain.chat_models.google_palm import (
    ChatGooglePalm,
    ChatGooglePalmError,
    _messages_to_prompt_dict,
)
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)


def test_messages_to_prompt_dict_with_valid_messages() -> None:
    """Convert a well-formed conversation and compare against the expected dict.

    The input is one system message, two human/AI example pairs, and one real
    human/AI exchange. The expected output carries the system message under
    ``"context"``, the example pairs under ``"examples"``, and the remaining
    messages under ``"messages"``.
    """
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")

    conversation = [
        SystemMessage(content="Prompt"),
        HumanMessage(example=True, content="Human example #1"),
        AIMessage(example=True, content="AI example #1"),
        HumanMessage(example=True, content="Human example #2"),
        AIMessage(example=True, content="AI example #2"),
        HumanMessage(content="Real human message"),
        AIMessage(content="Real AI message"),
    ]

    assert _messages_to_prompt_dict(conversation) == {
        "context": "Prompt",
        "examples": [
            {"author": "human", "content": "Human example #1"},
            {"author": "ai", "content": "AI example #1"},
            {"author": "human", "content": "Human example #2"},
            {"author": "ai", "content": "AI example #2"},
        ],
        "messages": [
            {"author": "human", "content": "Real human message"},
            {"author": "ai", "content": "Real AI message"},
        ],
    }


def test_messages_to_prompt_dict_raises_with_misplaced_system_message() -> None:
    """Expect ``ChatGooglePalmError`` when a system message follows a human one."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as excinfo:
        _messages_to_prompt_dict(
            [
                HumanMessage(content="Real human message"),
                SystemMessage(content="Prompt"),
            ]
        )
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "System message must be first" in str(excinfo.value)


def test_messages_to_prompt_dict_raises_with_misordered_examples() -> None:
    """Expect ``ChatGooglePalmError`` when an AI example precedes its human example."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as excinfo:
        _messages_to_prompt_dict(
            [
                AIMessage(example=True, content="AI example #1"),
                HumanMessage(example=True, content="Human example #1"),
            ]
        )
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "AI example message must be immediately preceded" in str(excinfo.value)


def test_messages_to_prompt_dict_raises_with_mismatched_examples() -> None:
    """Expect ``ChatGooglePalmError`` when a human example is followed by a non-example AI message."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as excinfo:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=False, content="AI example #1"),
            ]
        )
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "Human example message must be immediately followed" in str(
        excinfo.value
    )


def test_messages_to_prompt_dict_raises_with_example_after_real() -> None:
    """Expect ``ChatGooglePalmError`` when example messages follow a real message."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ChatGooglePalmError) as excinfo:
        _messages_to_prompt_dict(
            [
                HumanMessage(example=False, content="Real message"),
                HumanMessage(example=True, content="Human example #1"),
                AIMessage(example=True, content="AI example #1"),
            ]
        )
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "Message examples must come before other" in str(excinfo.value)


def test_chat_google_raises_with_invalid_temperature() -> None:
    """Expect ``ValueError`` when ``ChatGooglePalm`` is built with ``temperature=2.0``."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as excinfo:
        ChatGooglePalm(google_api_key="fake", temperature=2.0)
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "must be in the range" in str(excinfo.value)


def test_chat_google_raises_with_invalid_top_p() -> None:
    """Expect ``ValueError`` when ``ChatGooglePalm`` is built with ``top_p=2.0``."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as excinfo:
        ChatGooglePalm(google_api_key="fake", top_p=2.0)
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "must be in the range" in str(excinfo.value)


def test_chat_google_raises_with_invalid_top_k() -> None:
    """Expect ``ValueError`` when ``ChatGooglePalm`` is built with ``top_k=-5``."""
    # Skip (rather than fail) when the optional Google SDK is not installed.
    pytest.importorskip("google.generativeai")
    with pytest.raises(ValueError) as excinfo:
        ChatGooglePalm(google_api_key="fake", top_k=-5)
    # Check the raised exception's own message via ``excinfo.value``; str() of
    # the ExceptionInfo wrapper is pytest's representation of the wrapper, not
    # the exception text.
    assert "must be positive" in str(excinfo.value)