# Spaces:
# Build error
# Build error
import pytest
from utils import *

# Module-level default binding; create_server() rebinds this global with a
# freshly configured preset.
server = ServerPreset.tinyllama2()

# Multi-line prompt that is much longer than one slot's context window
# (test_ctx_shift_enabled notes it is 301 tokens vs. a 128-token slot).
LONG_TEXT = "\n".join((
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
    "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.",
    "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
))
# autouse: pytest never calls this otherwise (the name is not test_*), so the
# context sizes the tests rely on would never be applied.
# Function scope (the default): tests mutate `server` (disable_ctx_shift,
# n_predict), so each test gets a fresh preset instead of inheriting them.
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh, small-context preset.

    The preset gets a 256-token context shared by 2 slots, i.e. 256/2 = 128
    tokens per slot. The token counts asserted by the tests are derived from
    that per-slot size.
    """
    global server  # rebinds the module-level name the tests read
    server = ServerPreset.tinyllama2()
    server.n_ctx = 256
    server.n_slots = 2
def test_ctx_shift_enabled():
    """Generation runs past a full slot when context shifting is enabled.

    The prompt is 301 tokens, but a slot holds only 256/2 = 128 tokens. The
    prompt is truncated to its last 109 tokens. All 64 requested tokens are
    still generated because the context is shifted whenever it fills up.
    """
    server.start()
    res = server.make_request(
        "POST",
        "/completion",
        data={"n_predict": 64, "prompt": LONG_TEXT},
    )
    assert res.status_code == 200

    timings = res.body["timings"]
    assert timings["prompt_n"] == 109
    assert timings["predicted_n"] == 64
    assert res.body["truncated"] is True
# Without this parametrization pytest treats the arguments as fixtures and
# errors with "fixture 'n_predict' not found".
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    # Bounded request that fits easily in the slot: stops at n_predict.
    (64, 64, False),
    # Unlimited request: generation can only stop once the 128-token slot
    # context is full. 120 assumes the prompt tokenizes to 8 tokens with the
    # tinyllama2 preset (128 - 8) -- TODO confirm if the preset changes.
    (-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """Check generation limits for a short prompt with context shifting off.

    Args:
        n_predict: per-request token limit (-1 means no limit).
        n_token_output: expected number of generated tokens.
        truncated: expected value of the response's ``truncated`` flag.
    """
    server.disable_ctx_shift = True
    server.n_predict = -1  # server-wide default: no limit unless the request sets one
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    })
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == n_token_output
    assert res.body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
    """An over-long prompt is rejected when context shifting is disabled.

    With ``disable_ctx_shift`` set, completing ``LONG_TEXT`` is expected to
    fail. The response should have a non-200 status and an error message
    saying the input exceeds the available context size.
    """
    server.disable_ctx_shift = True
    server.start()

    request = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    res = server.make_request("POST", "/completion", data=request)

    assert res.status_code != 200
    assert "error" in res.body
    message = res.body["error"]["message"]
    assert "exceeds the available context size" in message