Spaces:
Running
Running
import os | |
import re | |
from typing import Union | |
import pytest | |
from _pytest.monkeypatch import MonkeyPatch | |
from requests import Response | |
from requests.exceptions import ConnectionError | |
from requests.sessions import Session | |
from xinference_client.client.restful.restful_client import ( | |
Client, | |
RESTfulChatglmCppChatModelHandle, | |
RESTfulChatModelHandle, | |
RESTfulEmbeddingModelHandle, | |
RESTfulGenerateModelHandle, | |
RESTfulRerankModelHandle, | |
) | |
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage | |
class MockXinferenceClass: | |
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]: | |
if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url): | |
raise RuntimeError('404 Not Found') | |
if 'generate' == model_uid: | |
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) | |
if 'chat' == model_uid: | |
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) | |
if 'embedding' == model_uid: | |
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) | |
if 'rerank' == model_uid: | |
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) | |
raise RuntimeError('404 Not Found') | |
def get(self: Session, url: str, **kwargs): | |
response = Response() | |
if 'v1/models/' in url: | |
# get model uid | |
model_uid = url.split('/')[-1] or '' | |
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ | |
model_uid not in ['generate', 'chat', 'embedding', 'rerank']: | |
response.status_code = 404 | |
response._content = b'{}' | |
return response | |
# check if url is valid | |
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): | |
response.status_code = 404 | |
response._content = b'{}' | |
return response | |
if model_uid in ['generate', 'chat']: | |
response.status_code = 200 | |
response._content = b'''{ | |
"model_type": "LLM", | |
"address": "127.0.0.1:43877", | |
"accelerators": [ | |
"0", | |
"1" | |
], | |
"model_name": "chatglm3-6b", | |
"model_lang": [ | |
"en" | |
], | |
"model_ability": [ | |
"generate", | |
"chat" | |
], | |
"model_description": "latest chatglm3", | |
"model_format": "pytorch", | |
"model_size_in_billions": 7, | |
"quantization": "none", | |
"model_hub": "huggingface", | |
"revision": null, | |
"context_length": 2048, | |
"replica": 1 | |
}''' | |
return response | |
elif model_uid == 'embedding': | |
response.status_code = 200 | |
response._content = b'''{ | |
"model_type": "embedding", | |
"address": "127.0.0.1:43877", | |
"accelerators": [ | |
"0", | |
"1" | |
], | |
"model_name": "bge", | |
"model_lang": [ | |
"en" | |
], | |
"revision": null, | |
"max_tokens": 512 | |
}''' | |
return response | |
elif 'v1/cluster/auth' in url: | |
response.status_code = 200 | |
response._content = b'''{ | |
"auth": true | |
}''' | |
return response | |
def _check_cluster_authenticated(self): | |
self._cluster_authed = True | |
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict: | |
# check if self._model_uid is a valid uuid | |
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ | |
self._model_uid != 'rerank': | |
raise RuntimeError('404 Not Found') | |
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url): | |
raise RuntimeError('404 Not Found') | |
if top_n is None: | |
top_n = 1 | |
return { | |
'results': [ | |
{ | |
'index': i, | |
'document': doc, | |
'relevance_score': 0.9 | |
} | |
for i, doc in enumerate(documents[:top_n]) | |
] | |
} | |
def create_embedding( | |
self: RESTfulGenerateModelHandle, | |
input: Union[str, list[str]], | |
**kwargs | |
) -> dict: | |
# check if self._model_uid is a valid uuid | |
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ | |
self._model_uid != 'embedding': | |
raise RuntimeError('404 Not Found') | |
if isinstance(input, str): | |
input = [input] | |
ipt_len = len(input) | |
embedding = Embedding( | |
object="list", | |
model=self._model_uid, | |
data=[ | |
EmbeddingData( | |
index=i, | |
object="embedding", | |
embedding=[1919.810 for _ in range(768)] | |
) | |
for i in range(ipt_len) | |
], | |
usage=EmbeddingUsage( | |
prompt_tokens=ipt_len, | |
total_tokens=ipt_len | |
) | |
) | |
return embedding | |
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' | |
def setup_xinference_mock(request, monkeypatch: MonkeyPatch): | |
if MOCK: | |
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) | |
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated) | |
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) | |
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) | |
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank) | |
yield | |
if MOCK: | |
monkeypatch.undo() |