File size: 3,776 Bytes
a8b3f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from unittest.mock import Mock, patch

import pytest

from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.mixedbread.rerank.rerank import MixedBreadRerankModel


def test_validate_credentials():
    """Check that credential validation rejects a bad key and accepts a good one.

    Two phases:
      1. An obviously invalid API key must raise ``CredentialsValidateFailedError``.
      2. With ``httpx.post`` patched to return a canned 200 rerank response,
         validation using the ``MIXEDBREAD_API_KEY`` environment variable must
         complete without raising.
    """
    rerank_model = MixedBreadRerankModel()

    # NOTE(review): nothing is patched for this call, so it presumably hits the
    # real Mixedbread endpoint and needs network access — confirm.
    with pytest.raises(CredentialsValidateFailedError):
        rerank_model.validate_credentials(
            model="mxbai-rerank-large-v1",
            credentials={"api_key": "invalid_key"},
        )

    # Canned successful rerank payload returned by the patched HTTP call.
    first_doc = {
        "index": 0,
        "score": 0.06762695,
        "input": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
        "States Census, Carson City had a population of 55,274.",
        "object": "text_document",
    }
    second_doc = {
        "index": 1,
        "score": 0.057403564,
        "input": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific "
        "Ocean that are a political division controlled by the United States. Its capital is "
        "Saipan.",
        "object": "text_document",
    }
    canned_payload = {
        "usage": {"prompt_tokens": 86, "total_tokens": 86},
        "model": "mixedbread-ai/mxbai-rerank-large-v1",
        "data": [first_doc, second_doc],
        "object": "list",
        "top_k": 2,
        "return_input": True,
    }

    fake_response = Mock()
    fake_response.json.return_value = canned_payload
    fake_response.status_code = 200

    with patch("httpx.post", return_value=fake_response):
        rerank_model.validate_credentials(
            model="mxbai-rerank-large-v1",
            credentials={"api_key": os.environ.get("MIXEDBREAD_API_KEY")},
        )


def test_invoke_model():
    """Invoke the rerank model against a mocked HTTP response and check score filtering.

    ``httpx.post`` is patched, so the test is driven entirely by the canned
    payload below. That payload scores the first document 0.6044922 and the
    second 0.0703125. With ``score_threshold=0.5``, only the first document
    (index 0) is expected in the returned ``RerankResult``.
    """
    model = MixedBreadRerankModel()
    with patch("httpx.post") as mock_post:
        mock_response = Mock()
        mock_response.json.return_value = {
            "usage": {"prompt_tokens": 56, "total_tokens": 56},
            "model": "mixedbread-ai/mxbai-rerank-large-v1",
            "data": [
                {
                    "index": 0,
                    "score": 0.6044922,
                    "input": "Kasumi is a girl name of Japanese origin meaning mist.",
                    "object": "text_document",
                },
                {
                    "index": 1,
                    "score": 0.0703125,
                    "input": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a "
                    "team named PopiParty.",
                    "object": "text_document",
                },
            ],
            "object": "list",
            "top_k": 2,
            # A JSON boolean, consistent with the payload mocked in
            # test_validate_credentials (previously the string "true").
            "return_input": True,
        }
        mock_response.status_code = 200
        mock_post.return_value = mock_response
        result = model.invoke(
            model="mxbai-rerank-large-v1",
            credentials={
                "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
            },
            query="Who is Kasumi?",
            docs=[
                "Kasumi is a girl name of Japanese origin meaning mist.",
                "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
                "PopiParty.",
            ],
            score_threshold=0.5,
        )

        assert isinstance(result, RerankResult)
        # Only the index-0 document scores above the 0.5 threshold.
        assert len(result.docs) == 1
        assert result.docs[0].index == 0
        assert result.docs[0].score >= 0.5