#!/usr/bin/env python
'''
This script fetches all the models used in the server tests.

This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
It is meant to be run from the root of the repository.

Example:
    python scripts/fetch_server_test_models.py
    ( cd examples/server/tests && ./tests.sh -v -x -m slow )
'''
import ast
import glob
import logging
import os
import subprocess
import sys
from typing import Generator, Optional

from pydantic import BaseModel
class HuggingFaceModel(BaseModel):
    '''A reference to a model on Hugging Face: a repo name plus an optional file within that repo.'''
    # e.g. "ggml-org/models"
    hf_repo: str
    # e.g. "tinyllamas/stories260K.gguf"; None when the repo itself identifies the model
    hf_file: Optional[str] = None
    class Config:
        # frozen makes instances immutable and hashable, so they can be deduplicated via set()
        frozen = True
def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
    '''
    Yield every HuggingFaceModel referenced by a @pytest.mark.parametrize decorator in test_file.

    Parses the file with ast (no import/execution of the test module), looks for decorators of
    the form `@<anything>.parametrize("<names>", [<tuples>])` where <names> contains "hf_repo",
    and extracts the hf_repo (and hf_file, if present) literal from each parameter tuple.
    Unparseable files are logged and skipped; non-literal parametrize entries are logged and skipped.
    '''
    try:
        with open(test_file) as f:
            tree = ast.parse(f.read())
    except Exception as e:
        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
        return

    for node in ast.walk(tree):
        # Async test functions are common with pytest-asyncio; include them too.
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for dec in node.decorator_list:
            if not (isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute)
                    and dec.func.attr == 'parametrize'):
                continue
            # Need positional (argnames, argvalues); keyword-style or malformed calls carry
            # nothing we can extract positionally.
            if len(dec.args) < 2:
                continue
            # pytest accepts "a, b" with spaces; strip so membership tests don't silently miss.
            param_names = [name.strip() for name in ast.literal_eval(dec.args[0]).split(",")]
            if "hf_repo" not in param_names:
                continue
            raw_param_values = dec.args[1]
            if not isinstance(raw_param_values, ast.List):
                logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
                continue
            hf_repo_idx = param_names.index("hf_repo")
            hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None
            for t in raw_param_values.elts:
                if not isinstance(t, ast.Tuple):
                    logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
                    continue
                yield HuggingFaceModel(
                    hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
                    hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None)
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # Dedupe via set (HuggingFaceModel is frozen, hence hashable); hf_file may be None, so
    # map it to '' in the sort key — otherwise sorting raises TypeError comparing None to str
    # when two models share the same hf_repo.
    models = sorted(
        {
            model
            for test_file in glob.glob('examples/server/tests/unit/test_*.py')
            for model in collect_hf_model_test_parameters(test_file)
        },
        key=lambda m: (m.hf_repo, m.hf_file or ''))

    logging.info(f'Found {len(models)} models in parameterized tests:')
    for m in models:
        logging.info(f' - {m.hf_repo} / {m.hf_file}')

    # Binary location is overridable for CI / custom build trees.
    cli_path = os.environ.get(
        'LLAMA_SERVER_BIN_PATH',
        os.path.join(
            os.path.dirname(__file__),
            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))

    for m in models:
        # '<' marks templated placeholder names in parametrize entries, not real repos/files.
        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
            continue
        if m.hf_file is not None and '-of-' in m.hf_file:
            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
            continue
        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
        # Run a 1-token generation purely to force llama-cli's HF download/cache path.
        cmd = [
            cli_path,
            '-hfr', m.hf_repo,
            *([] if m.hf_file is None else ['-hff', m.hf_file]),
            '-n', '1',
            '-p', 'Hey',
            '--no-warmup',
            '--log-disable',
            '-no-cnv']
        # Flash attention is skipped for models known not to support it.
        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
            cmd.append('-fa')
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
            sys.exit(1)