File size: 3,851 Bytes
20f348c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from unittest.mock import MagicMock

import google.generativeai.types.generation_types as generation_config_types  # type: ignore
import pytest
from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse

from extensions import ext_redis


class MockGoogleResponseClass:
    """Iterable stand-in for a streamed ``generate_content`` result.

    Iterating yields one not-done ``GenerateContentResponse`` for every
    character of the canned reply text, followed by exactly one done response.
    """

    # Class-level default; set to True on the instance right before the final chunk.
    _done = False

    def __iter__(self):
        canned_reply = "it's google!"

        # One intermediate (not-done) chunk per character of the canned reply.
        for _ in canned_reply:
            yield GenerateContentResponse(
                done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
            )

        # Terminal chunk: flag this instance as finished, then emit the done response.
        self._done = True
        yield GenerateContentResponse(
            done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
        )


class MockGoogleResponseCandidateClass:
    """Stand-in for a single response candidate that carries the canned reply."""

    finish_reason = "stop"

    @property
    def content(self) -> gag_content.Content:
        """Build a new ``Content`` whose only part holds the canned reply text."""
        reply_part = gag_content.Part(text="it's google!")
        return gag_content.Content(parts=[reply_part])


class MockGoogleClass:
    """Canned replacements for ``GenerativeModel`` / response members.

    The members defined here are installed over the SDK attributes by the
    ``setup_google_mock`` fixture.
    """

    @staticmethod
    def generate_content_sync() -> GenerateContentResponse:
        """Return a completed response wrapping an empty protobuf result."""
        empty_result = glm.GenerateContentResponse({})
        return GenerateContentResponse(done=True, iterator=None, result=empty_result, chunks=[])

    @staticmethod
    def generate_content_stream() -> MockGoogleResponseClass:
        """Return a new iterable mock that plays back the streamed reply."""
        streamed = MockGoogleResponseClass()
        return streamed

    def generate_content(
        self: GenerativeModel,
        contents: content_types.ContentsType,
        *,
        generation_config: generation_config_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
        stream: bool = False,
        **kwargs,
    ) -> GenerateContentResponse:
        """Replacement for ``GenerativeModel.generate_content``.

        Every argument except ``stream`` is accepted and ignored. A truthy
        ``stream`` selects the iterable streaming mock; otherwise the completed
        synchronous response is returned.
        """
        responder = MockGoogleClass.generate_content_stream if stream else MockGoogleClass.generate_content_sync
        return responder()

    @property
    def generative_response_text(self) -> str:
        """Canned text reply."""
        canned_reply = "it's google!"
        return canned_reply

    @property
    def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
        """Single-element list holding a newly created mock candidate."""
        only_candidate = MockGoogleResponseCandidateClass()
        return [only_candidate]


def mock_configure(api_key: str):
    """Replacement for ``google.generativeai.configure`` that only checks key length.

    Raises:
        Exception: with message ``"Invalid API key"`` when ``api_key`` is
            shorter than 16 characters.
    """
    key_is_long_enough = len(api_key) >= 16
    if key_is_long_enough:
        return

    raise Exception("Invalid API key")


class MockFileState:
    """State object whose ``name`` attribute is always ``"FINISHED"``."""

    def __init__(self) -> None:
        finished_label = "FINISHED"
        self.name = finished_label


class MockGoogleFile:
    """Stand-in for an uploaded file handle: a ``name`` plus a ``MockFileState``."""

    def __init__(self, name: str = "mock_file_name") -> None:
        # The two attributes are independent; state is always a new MockFileState.
        self.state = MockFileState()
        self.name = name


def mock_get_file(name: str) -> MockGoogleFile:
    """Replacement for ``google.generativeai.get_file``: return a mock file named ``name``."""
    looked_up = MockGoogleFile(name)
    return looked_up


def mock_upload_file(path: str, mime_type: str | None = None, **kwargs) -> MockGoogleFile:
    """Replacement for ``google.generativeai.upload_file``.

    Nothing is read or uploaded; every argument is ignored and a
    ``MockGoogleFile`` with its default name is returned.

    Args:
        path: Path of the file that would have been uploaded (unused).
        mime_type: MIME type of the file (unused). Optional so that callers
            which omit it do not fail only under the mock.
        **kwargs: Any further keyword arguments the caller passes to the SDK
            function (unused).

    Returns:
        A new ``MockGoogleFile``.
    """
    return MockGoogleFile()


@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
    """Patch the Google generative AI SDK with the mocks defined in this module.

    Patches are applied before the test body runs (the ``yield``) and are
    explicitly undone afterwards.
    """
    # Attributes replaced directly on SDK classes.
    class_patches = (
        (BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text),
        (BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates),
        (GenerativeModel, "generate_content", MockGoogleClass.generate_content),
    )
    for owner, attribute, replacement in class_patches:
        monkeypatch.setattr(owner, attribute, replacement)

    # Module-level SDK functions, addressed by dotted import path.
    function_patches = {
        "google.generativeai.configure": mock_configure,
        "google.generativeai.get_file": mock_get_file,
        "google.generativeai.upload_file": mock_upload_file,
    }
    for import_path, replacement in function_patches.items():
        monkeypatch.setattr(import_path, replacement)

    yield

    monkeypatch.undo()


@pytest.fixture
def setup_mock_redis() -> None:
    """Replace ``get``, ``setex`` and ``exists`` on the shared redis client with mocks.

    ``get`` and ``setex`` return ``None``; ``exists`` returns ``True``.
    NOTE(review): the attributes are assigned directly and this fixture has no
    teardown, so the mocks stay in place after the test finishes.
    """
    canned_returns = {"get": None, "setex": None, "exists": True}
    for method_name, canned_value in canned_returns.items():
        setattr(ext_redis.redis_client, method_name, MagicMock(return_value=canned_value))