File size: 1,538 Bytes
4569a8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# tests/test_inference.py
import pytest
from inference import chat_completion, stream_chat_completion

class DummyStream:
    """Minimal iterable standing in for a streaming completion response.

    Wraps the pre-built chunk objects handed to the constructor and hands
    them back, in order, whenever the stream is iterated.
    """

    def __init__(self, chunks):
        # Stored as given; a fresh iterator is created per iteration.
        self._chunks = chunks

    def __iter__(self):
        """Return a new iterator over the stored chunks."""
        stored = self._chunks
        return iter(stored)

class DummyClient:
    """Stub inference client that returns canned chat-completion results.

    The stub is its own ``chat`` and ``completions`` namespace, so the
    attribute chain ``client.chat.completions.create(...)`` resolves to
    :meth:`create` on this object. ``completions`` used to be a plain
    method, which made that chain fail with ``AttributeError`` (a bound
    method has no ``create``); it is now a property. Calling it, as in
    ``client.completions(**kwargs)``, is still supported via ``__call__``.

    Args:
        response: Text placed in ``choices[0].message.content`` of the
            non-streaming result.
    """

    def __init__(self, response):
        self.response = response
        # ``client.chat`` points back at the stub itself.
        self.chat = self

    @property
    def completions(self):
        """Return the stub itself so ``.completions.create`` is reachable."""
        return self

    def __call__(self, **kwargs):
        """Keep the legacy ``client.completions(**kwargs)`` call form working."""
        return self

    def create(self, **kwargs):
        """Return a canned completion.

        With a truthy ``stream`` keyword, returns a ``DummyStream`` whose
        chunks carry the delta contents ``"h"`` and ``"i"``. Otherwise
        returns a single response object whose first choice's message
        content is ``self.response``. All other keyword arguments are
        accepted and ignored.
        """
        from types import SimpleNamespace

        if kwargs.get("stream"):
            chunks = [
                SimpleNamespace(
                    choices=[SimpleNamespace(delta=SimpleNamespace(content=piece))]
                )
                for piece in ("h", "i")
            ]
            return DummyStream(chunks)

        message = SimpleNamespace(content=self.response)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

@pytest.fixture(autouse=True)
def patch_client(monkeypatch):
    """Swap the inference-client factory for a stub in every test.

    Autouse, so each test in this module runs against a ``DummyClient``
    whose non-streaming reply is ``"hello"``. The patches are undone by
    ``monkeypatch`` at test teardown.
    """
    def fake(model_id=None, provider=None, **kwargs):
        # Arguments are ignored; ``provider`` is optional so the stub works
        # whether the caller passes it positionally, by keyword, or not at all.
        return DummyClient("hello")

    monkeypatch.setattr('hf_client.get_inference_client', fake)
    # NOTE(review): ``inference`` presumably binds the factory with
    # ``from hf_client import get_inference_client`` — confirm. If so, the
    # patch above does not reach it, so the name is patched where it is
    # looked up as well. ``raising=False`` keeps this harmless when
    # ``inference`` has no such attribute.
    monkeypatch.setattr('inference.get_inference_client', fake, raising=False)

def test_chat_completion():
    """Non-streaming completion returns the stubbed reply text."""
    messages = [{"role": "user", "content": "hi"}]
    reply = chat_completion("any-model", messages)
    assert reply == "hello"

def test_stream_chat_completion():
    """Streaming completion yields pieces that concatenate to ``"hi"``."""
    messages = [{"role": "user", "content": "stream"}]
    pieces = []
    for piece in stream_chat_completion("any-model", messages):
        pieces.append(piece)
    assert "".join(pieces) == "hi"