import pytest
import requests
import time
from openai import OpenAI

from utils import *

server = ServerPreset.tinyllama2()

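# Assumed fixture wiring (a sketch): an autouse, module-scoped fixture so the
# server preset is re-created before the tests in this file run.
@pytest.fixture(scope="module", autouse=True)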
def create_server():
    global server
    server = ServerPreset.tinyllama2()

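# Assumed parametrization: illustrative cases only; the regexes and the expected
# prompt/predicted token counts are guesses for the tinyllama2 preset, not verified.
@pytest.mark.parametrize(
    "prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens",
    [
        ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
        ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone)+", 46, 64, False, True),
    ],
)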
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "return_tokens": return_tokens,
    })
    assert res.status_code == 200
    assert res.body["timings"]["prompt_n"] == n_prompt
    assert res.body["timings"]["predicted_n"] == n_predicted
    assert res.body["truncated"] == truncated
    assert type(res.body["has_new_line"]) == bool
    assert match_regex(re_content, res.body["content"])
    if return_tokens:
        assert len(res.body["tokens"]) > 0
        assert all(type(tok) == int for tok in res.body["tokens"])
    else:
        assert res.body["tokens"] == []

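# Assumed parametrization: same illustrative cases as test_completion, minus
# return_tokens; the expected token counts are unverified for the preset model.
@pytest.mark.parametrize(
    "prompt,n_predict,re_content,n_prompt,n_predicted,truncated",
    [
        ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
        ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone)+", 46, 64, False),
    ],
)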
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": prompt,
        "stream": True,
    })
    content = ""
    for data in res:
        assert "stop" in data and type(data["stop"]) == bool
        if data["stop"]:
            assert data["timings"]["prompt_n"] == n_prompt
            assert data["timings"]["predicted_n"] == n_predicted
            assert data["truncated"] == truncated
            assert data["stop_type"] == "limit"
            assert type(data["has_new_line"]) == bool
            assert "generation_settings" in data
            assert server.n_predict is not None
            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
            assert data["generation_settings"]["seed"] == server.seed
            assert match_regex(re_content, content)
        else:
            assert len(data["tokens"]) > 0
            assert all(type(tok) == int for tok in data["tokens"])
            content += data["content"]

def test_completion_stream_vs_non_stream():
    global server
    server.start()
    res_stream = server.make_stream_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
        "stream": True,
    })
    res_non_stream = server.make_request("POST", "/completion", data={
        "n_predict": 8,
        "prompt": "I believe the meaning of life is",
    })
    content_stream = ""
    for data in res_stream:
        content_stream += data["content"]
    assert content_stream == res_non_stream.body["content"]

def test_completion_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
    )
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].text is not None
    assert match_regex("(going|bed)+", res.choices[0].text)

def test_completion_stream_with_openai_library():
    global server
    server.start()
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.completions.create(
        model="davinci-002",
        prompt="I believe the meaning of life is",
        max_tokens=8,
        stream=True,
    )
    output_text = ""
    for data in res:
        choice = data.choices[0]
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("(going|bed)+", output_text)

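# Assumed slot counts, chosen to cover both single-slot and multi-slot operation.
@pytest.mark.parametrize("n_slots", [1, 2])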
def test_consistent_result_same_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": 0.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res

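# Assumed slot counts, mirroring the parametrization of the same-seed test above.
@pytest.mark.parametrize("n_slots", [1, 2])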
def test_different_result_different_seed(n_slots: int):
    global server
    server.n_slots = n_slots
    server.start()
    last_res = None
    for seed in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": seed,
            "temperature": 1.0,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] != last_res.body["content"]
        last_res = res

# TODO: figure out why this does not work with temperature = 1
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
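@pytest.mark.parametrize("n_batch", [16, 32])   # assumed batch sizes, chosen for illustration
@pytest.mark.parametrize("temperature", [0.0])  # assumed: only the deterministic case, per the TODO above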
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
    global server
    server.n_batch = n_batch
    server.start()
    last_res = None
    for _ in range(4):
        res = server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
            "seed": 42,
            "temperature": temperature,
            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
        })
        if last_res is not None:
            assert res.body["content"] == last_res.body["content"]
        last_res = res

def test_cache_vs_nocache_prompt():
    global server
    server.start()
    res_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": True,
    })
    res_no_cache = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "seed": 42,
        "temperature": 1.0,
        "cache_prompt": False,
    })
    assert res_cache.body["content"] == res_no_cache.body["content"]

def test_completion_with_tokens_input():
    global server
    server.temperature = 0.0
    server.start()
    prompt_str = "I believe the meaning of life is"
    res = server.make_request("POST", "/tokenize", data={
        "content": prompt_str,
        "add_special": True,
    })
    assert res.status_code == 200
    tokens = res.body["tokens"]

    # single completion
    res = server.make_request("POST", "/completion", data={
        "prompt": tokens,
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str

    # batch completion
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, tokens],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens
    res = server.make_request("POST", "/completion", data={
        "prompt": [tokens, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body) == list
    assert len(res.body) == 2
    assert res.body[0]["content"] == res.body[1]["content"]

    # mixed string and tokens in one sequence
    res = server.make_request("POST", "/completion", data={
        "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
    })
    assert res.status_code == 200
    assert type(res.body["content"]) == str

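# Assumed parametrization: illustrative (n_slots, n_requests) combinations that
# exercise both the fully-busy and partially-busy branches of check_slots_status.
@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 3),
    (2, 2),
    (2, 4),
    (4, 2),
    (4, 6),
])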
def test_completion_parallel_slots(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.temperature = 0.0
    server.start()

    PROMPTS = [
        ("Write a very long book.", "(very|special|big)+"),
        ("Write another poem.", "(small|house)+"),
        ("What is LLM?", "(Dad|said)+"),
        ("The sky is blue and I love it.", "(climb|leaf)+"),
        ("Write another very long music lyrics.", "(friends|step|sky)+"),
        ("Write a very long joke.", "(cat|Whiskers)+"),
    ]

    def check_slots_status():
        should_all_slots_busy = n_requests >= n_slots
        time.sleep(0.1)
        res = server.make_request("GET", "/slots")
        n_busy = sum([1 for slot in res.body if slot["is_processing"]])
        if should_all_slots_busy:
            assert n_busy == n_slots
        else:
            assert n_busy <= n_slots

    tasks = []
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        tasks.append((server.make_request, ("POST", "/completion", {
            "prompt": prompt,
            "seed": 42,
            "temperature": 1.0,
        })))
    tasks.append((check_slots_status, ()))
    results = parallel_function_calls(tasks)

    # check results
    for i in range(n_requests):
        prompt, re_content = PROMPTS[i % len(PROMPTS)]
        res = results[i]
        assert res.status_code == 200
        assert type(res.body["content"]) == str
        assert len(res.body["content"]) > 10
        # FIXME: the result is not deterministic when using a slot other than slot 0
        # assert match_regex(re_content, res.body["content"])

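# Assumed parametrization: one case with no field filter and one requesting exactly
# the three fields the assertions below check; the values are illustrative.
@pytest.mark.parametrize(
    "prompt,n_predict,response_fields",
    [
        ("I believe the meaning of life is", 8, []),
        ("I believe the meaning of life is", 8, ["content", "generation_settings/n_predict", "prompt"]),
    ],
)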
def test_completion_response_fields(
    prompt: str, n_predict: int, response_fields: list[str]
):
    global server
    server.start()
    res = server.make_request(
        "POST",
        "/completion",
        data={
            "n_predict": n_predict,
            "prompt": prompt,
            "response_fields": response_fields,
        },
    )
    assert res.status_code == 200
    assert "content" in res.body
    assert len(res.body["content"])
    if len(response_fields):
        assert res.body["generation_settings/n_predict"] == n_predict
        assert res.body["prompt"] == "<s> " + prompt
        assert isinstance(res.body["content"], str)
        assert len(res.body) == len(response_fields)
    else:
        assert len(res.body)
        assert "generation_settings" in res.body

def test_n_probs():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "logprob" in tok and tok["logprob"] <= 0.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_logprobs"]) == 10
        for prob in tok["top_logprobs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "logprob" in prob and prob["logprob"] <= 0.0
            assert "bytes" in prob and type(prob["bytes"]) == list

def test_n_probs_stream():
    global server
    server.start()
    res = server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "stream": True,
    })
    for data in res:
        if data["stop"] == False:
            assert "completion_probabilities" in data
            assert len(data["completion_probabilities"]) == 1
            for tok in data["completion_probabilities"]:
                assert "id" in tok and tok["id"] > 0
                assert "token" in tok and type(tok["token"]) == str
                assert "logprob" in tok and tok["logprob"] <= 0.0
                assert "bytes" in tok and type(tok["bytes"]) == list
                assert len(tok["top_logprobs"]) == 10
                for prob in tok["top_logprobs"]:
                    assert "id" in prob and prob["id"] > 0
                    assert "token" in prob and type(prob["token"]) == str
                    assert "logprob" in prob and prob["logprob"] <= 0.0
                    assert "bytes" in prob and type(prob["bytes"]) == list

def test_n_probs_post_sampling():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "post_sampling_probs": True,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_probs"]) == 10
        for prob in tok["top_probs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
            assert "bytes" in prob and type(prob["bytes"]) == list
        # because the test model usually outputs tokens with either 100% or 0%
        # probability, we check all the top_probs
        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])

def test_cancel_request():
    global server
    server.n_ctx = 4096
    server.n_predict = -1
    server.n_slots = 1
    server.server_slots = True
    server.start()
    # send a request that will take a long time, but cancel it before it finishes
    try:
        server.make_request("POST", "/completion", data={
            "prompt": "I believe the meaning of life is",
        }, timeout=0.1)
    except requests.exceptions.ReadTimeout:
        pass  # expected
    # make sure the slot is free
    time.sleep(1)  # wait for HTTP_POLLING_SECONDS
    res = server.make_request("GET", "/slots")
    assert res.body[0]["is_processing"] == False