mgbam commited on
Commit
4cb627f
·
verified ·
1 Parent(s): 0181a1f

Create tests/test_models.py

Browse files
Files changed (1) hide show
  1. tests/test_models.py +74 -0
tests/test_models.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## tests/test_models.py
2
+ ```python
3
+ import pytest
4
+ from models import AVAILABLE_MODELS, find_model, ModelInfo
5
+
6
+ @pyte st.mark.parametrize("identifier, expected_id", [
7
+ ("Moonshot Kimi-K2", "moonshotai/Kimi-K2-Instruct"),
8
+ ("moonshotai/Kimi-K2-Instruct", "moonshotai/Kimi-K2-Instruct"),
9
+ ("openai/gpt-4", "openai/gpt-4"),
10
+ ])
11
+ def test_find_model(identifier, expected_id):
12
+ model = find_model(identifier)
13
+ assert isinstance(model, ModelInfo)
14
+ assert model.id == expected_id
15
+
16
+
17
+ def test_find_model_not_found():
18
+ assert find_model("nonexistent-model") is None
19
+
20
+
21
+ def test_available_models_have_unique_ids():
22
+ ids = [m.id for m in AVAILABLE_MODELS]
23
+ assert len(ids) == len(set(ids))
24
+ ```
25
+
26
+ ## tests/test_inference.py
27
+ ```python
28
+ import pytest
29
+ from inference import chat_completion, stream_chat_completion
30
+ from models import ModelInfo
31
+
32
+ class DummyClient:
33
+ def __init__(self, response):
34
+ self.response = response
35
+ self.chat = self
36
+ n
37
+ def completions(self, **kwargs):
38
+ class Choice: pass
39
+ choice = type('C', (), {'message': type('M', (), {'content': self.response})})
40
+ return type('R', (), {'choices': [choice]})
41
+
42
+ @pytest.fixture(autouse=True)
43
+ def patch_client(monkeypatch):
44
+ # Patch hf_client.get_inference_client
45
+ from hf_client import get_inference_client
46
+ def fake_client(model_id, provider):
47
+ client = DummyClient("hello world")
48
+ client.chat = client
49
+ client.chat.completions = client
50
+ return client
51
+ monkeypatch.setattr('hf_client.get_inference_client', fake_client)
52
+
53
+
54
+ def test_chat_completion_returns_text():
55
+ msg = [{'role': 'user', 'content': 'test'}]
56
+ result = chat_completion('any-model', msg)
57
+ assert isinstance(result, str)
58
+ assert result == 'hello world'
59
+
60
+
61
+ def test_stream_chat_completion_yields_chunks():
62
+ # simulate streaming
63
+ class StreamClient(DummyClient):
64
+ def completions(self, **kwargs):
65
+ # emulate generator
66
+ chunks = [type('C', (), {'choices': [type('Ch', (), {'delta': type('D', (), {'content': 'h'})})]}),
67
+ type('C', (), {'choices': [type('Ch', (), {'delta': type('D', (), {'content': 'i'})})]})]
68
+ return iter(chunks)
69
+ from hf_client import get_inference_client as real_get
70
+ monkeypatch.setattr('hf_client.get_inference_client', lambda mid, prov: StreamClient(None))
71
+
72
+ msg = [{'role': 'user', 'content': 'stream'}]
73
+ chunks = list(stream_chat_completion('any-model', msg))
74
+ assert ''.join(chunks) == 'hi'