File size: 4,014 Bytes
a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from collections.abc import Generator

import google.generativeai.types.generation_types as generation_config_types
import pytest
from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel
from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse

# API key captured by the mocked ``_ClientManager.make_client`` from the configured
# ``client_options``; the mocked ``generate_content`` rejects keys shorter than 16 characters.
current_api_key = ""


class MockGoogleResponseClass:
    """Fake streaming response for ``generate_content(..., stream=True)``.

    Iterating it yields one unfinished ``GenerateContentResponse`` per character of the
    canned reply, followed by a single finished one. Every chunk carries an empty result
    payload; the reply text is presumably supplied by the patched ``text`` property
    (see ``setup_google_mock``).
    """

    # Flipped to True on the instance just before the final (done) chunk is yielded.
    _done = False

    def __iter__(self):
        canned_reply = "it's google!"

        # One in-progress chunk for each character of the reply...
        for _ in canned_reply:
            yield GenerateContentResponse(
                done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
            )

        # ...then mark the stream finished and hand out the terminal chunk.
        self._done = True
        yield GenerateContentResponse(
            done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
        )


class MockGoogleResponseCandidateClass:
    """Fake response candidate with a fixed finish reason and the canned reply as content."""

    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        """Build a fresh ``Content`` holding a single text part with the canned reply."""
        reply_part = gag_content.Part(text="it's google!")
        return gag_content.Content(parts=[reply_part])


class MockGoogleClass:
    """Offline stand-ins for google-generativeai SDK methods.

    ``setup_google_mock`` monkeypatches these functions onto the real SDK classes, so the
    instance methods receive an instance of the patched class (``GenerativeModel``,
    ``_ClientManager``, or a response object) as ``self`` rather than a ``MockGoogleClass``.
    """

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Return a single, already-finished response with an empty result payload."""
        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])

    @staticmethod
    def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
        """Return an iterable of streaming chunks (an instance of ``MockGoogleResponseClass``)."""
        return MockGoogleResponseClass()

    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        **kwargs,
    ) -> GenerateContentResponse:
        """Fake ``GenerativeModel.generate_content``.

        The signature mirrors the real method so it can be patched in place; only
        ``stream`` influences the result.

        Returns:
            The streaming fake when ``stream`` is true, otherwise a single finished response.

        Raises:
            Exception: if the API key recorded by ``make_client`` is missing or shorter
                than 16 characters.
        """
        # Only reading the module-level key, so no ``global`` statement is needed.
        # ``client_options.api_key`` may be None, so a falsy key is rejected up front
        # instead of letting ``len(None)`` raise TypeError.
        if not current_api_key or len(current_api_key) < 16:
            raise Exception("Invalid API key")

        if stream:
            return MockGoogleClass.generate_content_stream()

        return MockGoogleClass.generate_content_sync()

    @property
    def generative_response_text(self) -> str:
        """Canned reply text; patched onto ``BaseGenerateContentResponse.text``."""
        return "it's google!"

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        """Single fake candidate; patched onto ``BaseGenerateContentResponse.candidates``."""
        return [MockGoogleResponseCandidateClass()]

    def make_client(self: _ClientManager, name: str):
        """Fake ``_ClientManager.make_client``: build a glm client without running its ``__init__``.

        Also records the configured API key in the module-level ``current_api_key`` so the
        fake ``generate_content`` can validate it.

        Args:
            name: service name such as ``"generative"``; a ``"_async"`` suffix selects the
                ``*ServiceAsyncClient`` class instead of ``*ServiceClient``.

        Returns:
            An instance of the matching glm client class whose ``__init__`` was skipped.
        """
        global current_api_key

        if name.endswith("_async"):
            name = name.split("_")[0]
            cls = getattr(glm, name.title() + "ServiceAsyncClient")
        else:
            cls = getattr(glm, name.title() + "ServiceClient")

        # Attempt to configure using defaults.
        if not self.client_config:
            configure()

        client_options = self.client_config.get("client_options", None)
        if client_options:
            current_api_key = client_options.api_key

        def nop(self, *args, **kwargs):
            pass

        # Temporarily replace __init__ with a no-op so the real client constructor
        # (presumably transport/credential setup) never runs. Always restore it, even if
        # instantiation raises, so the real glm class is not left broken for later tests.
        original_init = cls.__init__
        cls.__init__ = nop
        try:
            client: glm.GenerativeServiceClient = cls(**self.client_config)
        finally:
            cls.__init__ = original_init

        # Always hand back the client, whether or not default metadata is configured;
        # this fake does not apply ``default_metadata``.
        return client


@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Swap the google-generativeai SDK entry points for offline fakes during one test.

    Patches ``BaseGenerateContentResponse.text`` / ``.candidates``,
    ``GenerativeModel.generate_content`` and ``_ClientManager.make_client`` with the
    members of ``MockGoogleClass``, then undoes the patches after the test body runs.
    """
    replacements = (
        (BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text),
        (BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates),
        (GenerativeModel, "generate_content", MockGoogleClass.generate_content),
        (_ClientManager, "make_client", MockGoogleClass.make_client),
    )
    for target, attribute, fake in replacements:
        monkeypatch.setattr(target, attribute, fake)

    yield

    # Restore the original SDK attributes explicitly once the test has finished.
    monkeypatch.undo()