mgbam commited on
Commit
4569a8b
·
verified ·
1 Parent(s): 668e783

Create test_inference.py

Browse files
Files changed (1) hide show
  1. tests/test_inference.py +43 -0
tests/test_inference.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_inference.py
2
+ import pytest
3
+ from inference import chat_completion, stream_chat_completion
4
+
5
+ class DummyStream:
6
+ def __init__(self, chunks):
7
+ self._chunks = chunks
8
+ def __iter__(self):
9
+ return iter(self._chunks)
10
+
11
+ class DummyClient:
12
+ def __init__(self, response):
13
+ self.response = response
14
+ self.chat = self
15
+ def completions(self, **kwargs):
16
+ return self
17
+ def create(self, **kwargs):
18
+ # if stream=True, return DummyStream
19
+ if kwargs.get("stream"):
20
+ from types import SimpleNamespace
21
+ chunks = [
22
+ SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="h"))]),
23
+ SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="i"))])
24
+ ]
25
+ return DummyStream(chunks)
26
+ # non-stream
27
+ from types import SimpleNamespace
28
+ return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self.response))])
29
+
30
+ @pytest.fixture(autouse=True)
31
+ def patch_client(monkeypatch):
32
+ from hf_client import get_inference_client
33
+ def fake(model_id, provider):
34
+ return DummyClient("hello")
35
+ monkeypatch.setattr('hf_client.get_inference_client', fake)
36
+
37
+ def test_chat_completion():
38
+ out = chat_completion("any-model", [{"role":"user","content":"hi"}])
39
+ assert out == "hello"
40
+
41
+ def test_stream_chat_completion():
42
+ chunks = list(stream_chat_completion("any-model", [{"role":"user","content":"stream"}]))
43
+ assert "".join(chunks) == "hi"