Spaces:
Build error
Build error
import pytest | |
from utils import * | |
# Module-level server handle shared by the tests in this file. The autouse
# fixture below rebinds it to a freshly configured preset before each test.
server = ServerPreset.tinyllama2()


@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh, configured preset.

    Registered as an autouse fixture so pytest invokes it before every test
    in this module. Without the decorator nothing calls this function, and
    the tests would start the unconfigured preset created at import time
    (no ``slot_save_path``, default temperature).
    """
    global server
    server = ServerPreset.tinyllama2()
    # NOTE(review): presumably the directory the server uses for the slot
    # files named in the save/restore requests -- confirm against utils.
    server.slot_save_path = "./tmp"
    # NOTE(review): presumably pins sampling so the regex assertions on the
    # generated text stay stable -- confirm.
    server.temperature = 0.0
def test_slot_save_restore():
    """Round-trip a slot's prompt cache through save and restore.

    Sequence, all against one running server:
      1. Run the France prompt on slot 1 with prompt caching on; all prompt
         tokens are processed.
      2. Save slot 1 to ``slot1.bin``.
      3. Run the Germany prompt on slot 1; only the differing part of the
         prompt is processed.
      4. Restore ``slot1.bin`` into slot 0 and run the Germany prompt there;
         again only the differing part is processed.
      5. Run the Germany prompt on slot 1 once more to verify that loading
         into slot 0 did not corrupt slot 1.
    """
    france_prompt = "What is the capital of France?"
    germany_prompt = "What is the capital of Germany?"

    def request_completion(prompt, slot_id):
        # Issue a cached completion on the given slot and require HTTP 200.
        reply = server.make_request("POST", "/completion", data={
            "prompt": prompt,
            "id_slot": slot_id,
            "cache_prompt": True,
        })
        assert reply.status_code == 200
        return reply

    server.start()

    # Step 1: nothing is cached yet, so the whole prompt is evaluated.
    first = request_completion(france_prompt, 1)
    assert match_regex("(Whiskers|Flana)+", first.body["content"])
    assert first.body["timings"]["prompt_n"] == 21  # all tokens are processed

    # Step 2: persist slot 1.
    saved = server.make_request("POST", "/slots/1?action=save", data={
        "filename": "slot1.bin",
    })
    assert saved.status_code == 200
    assert saved.body["n_saved"] == 84

    # Step 3: slot 1 still holds its cache, so only the tail is evaluated.
    second = request_completion(germany_prompt, 1)
    assert match_regex("(Jack|said)+", second.body["content"])
    assert second.body["timings"]["prompt_n"] == 6  # only different part is processed

    # Step 4: load the saved state into the other slot.
    restored = server.make_request("POST", "/slots/0?action=restore", data={
        "filename": "slot1.bin",
    })
    assert restored.status_code == 200
    assert restored.body["n_restored"] == 84

    # Slot 0 now has the restored cache, so again only the tail is evaluated.
    third = request_completion(germany_prompt, 0)
    assert match_regex("(Jack|said)+", third.body["content"])
    assert third.body["timings"]["prompt_n"] == 6  # only different part is processed

    # Step 5: slot 1 must be unaffected by the restore into slot 0.
    # NOTE(review): prompt_n == 1 presumably reflects the prompt already
    # being cached on slot 1 from step 3 -- confirm.
    fourth = request_completion(germany_prompt, 1)
    assert match_regex("(Jack|said)+", fourth.body["content"])
    assert fourth.body["timings"]["prompt_n"] == 1
def test_slot_erase():
    """Erasing a slot discards its prompt cache.

    Runs the same prompt on slot 1 before and after an erase request and
    expects the full prompt to be processed both times.
    """

    def assert_whole_prompt_processed():
        # Run the France prompt on slot 1 with caching on and check that
        # every prompt token was evaluated.
        reply = server.make_request("POST", "/completion", data={
            "prompt": "What is the capital of France?",
            "id_slot": 1,
            "cache_prompt": True,
        })
        assert reply.status_code == 200
        assert match_regex("(Whiskers|Flana)+", reply.body["content"])
        assert reply.body["timings"]["prompt_n"] == 21  # all tokens are processed

    server.start()

    assert_whole_prompt_processed()

    # erase slot 1
    erased = server.make_request("POST", "/slots/1?action=erase")
    assert erased.status_code == 200

    # re-run the same prompt, it should process all tokens again
    assert_whole_prompt_processed()