File size: 4,140 Bytes
5a29263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python
'''
    This script fetches all the models used in the server tests.

    This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.
    Each model found in the parametrized unit tests is fetched by running llama-cli once against it.

    It is meant to be run from the root of the repository.

    Example:
        python scripts/fetch_server_test_models.py
        ( cd examples/server/tests && ./tests.sh -v -x -m slow )
'''
import ast
import glob
import logging
import os
from typing import Generator
from pydantic import BaseModel
from typing import Optional
import subprocess


class HuggingFaceModel(BaseModel):
    '''
        A Hugging Face model reference collected from a parametrized server test.

        The model is frozen (immutable and hashable), so instances can be
        deduplicated by putting them in a set.
    '''
    hf_repo: str  # repository id; passed to llama-cli as -hfr
    hf_file: Optional[str] = None  # file inside the repository; passed as -hff when not None

    class Config:
        # pydantic model config: frozen models are immutable and hashable
        frozen = True


def _parametrize_arg(dec: ast.Call, position: int, keyword: str) -> Optional[ast.expr]:
    '''Return the AST node of a parametrize() argument given positionally or by keyword, or None.'''
    if len(dec.args) > position:
        return dec.args[position]
    for kw in dec.keywords:
        if kw.arg == keyword:
            return kw.value
    return None


def _parse_param_names(raw_names) -> list:
    '''Normalize parametrize argnames like pytest does: a comma-separated string or a sequence of names.'''
    if isinstance(raw_names, str):
        # pytest strips whitespace, so "model, hf_repo" declares "hf_repo"
        return [name.strip() for name in raw_names.split(',') if name.strip()]
    return [str(name).strip() for name in raw_names]


def _is_param_call(entry: ast.expr) -> bool:
    '''Return True if `entry` is a `pytest.param(...)` (or bare `param(...)`) call.'''
    if not isinstance(entry, ast.Call):
        return False
    func = entry.func
    return ((isinstance(func, ast.Attribute) and func.attr == 'param')
            or (isinstance(func, ast.Name) and func.id == 'param'))


def _param_set_values(entry: ast.expr, n_names: int) -> Optional[list]:
    '''Return the per-argument AST nodes of one parametrize entry, or None if its shape is not recognized.'''
    if _is_param_call(entry):
        # pytest.param(v1, v2, ..., id=..., marks=...): the positional args are the values
        return list(entry.args)
    if n_names == 1:
        # With a single argname each entry is the value itself (even when it is a tuple)
        return [entry]
    if isinstance(entry, (ast.Tuple, ast.List)):
        return list(entry.elts)
    return None


def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
    '''
        Yield the Hugging Face models referenced by `@...parametrize(...)` decorators in a test file.

        The file is parsed (never imported or executed). Every decorator call whose attribute name
        is `parametrize` and whose argnames include `hf_repo` is inspected. Each literal parameter
        set yields one HuggingFaceModel, with `hf_file` set when that argname is also present.
        Decorators or entries that cannot be evaluated statically are skipped with a warning.

        Args:
            test_file: path of the Python test file to scan.

        Yields:
            HuggingFaceModel instances (duplicates are possible).
    '''
    try:
        with open(test_file, encoding='utf-8') as f:
            tree = ast.parse(f.read(), filename=test_file)
    except Exception as e:
        # Best effort: a file that cannot be read or parsed is reported and ignored.
        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
        return

    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for dec in node.decorator_list:
            if not (isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute)
                    and dec.func.attr == 'parametrize'):
                continue

            names_node = _parametrize_arg(dec, 0, 'argnames')
            if names_node is None:
                logging.warning(f'Skipping parametrize without argnames at {test_file}:{node.lineno}')
                continue
            try:
                param_names = _parse_param_names(ast.literal_eval(names_node))
            except (ValueError, TypeError, SyntaxError):
                logging.warning(f'Skipping non-literal parametrize argnames at {test_file}:{node.lineno}')
                continue
            if "hf_repo" not in param_names:
                continue

            raw_param_values = _parametrize_arg(dec, 1, 'argvalues')
            if not isinstance(raw_param_values, (ast.List, ast.Tuple)):
                logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
                continue

            n_names = len(param_names)
            hf_repo_idx = param_names.index("hf_repo")
            hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None

            for t in raw_param_values.elts:
                values = _param_set_values(t, n_names)
                if values is None:
                    logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
                    continue
                if len(values) != n_names:
                    logging.warning(
                        f'Skipping parametrize entry with {len(values)} values for {n_names} names'
                        f' at {test_file}:{t.lineno}')
                    continue
                try:
                    hf_repo = ast.literal_eval(values[hf_repo_idx])
                    hf_file = ast.literal_eval(values[hf_file_idx]) if hf_file_idx is not None else None
                except (ValueError, TypeError, SyntaxError):
                    logging.warning(f'Skipping non-literal parametrize entry at {test_file}:{t.lineno}')
                    continue
                if not isinstance(hf_repo, str) or not (hf_file is None or isinstance(hf_file, str)):
                    logging.warning(f'Skipping parametrize entry with non-string hf_repo/hf_file at {test_file}:{t.lineno}')
                    continue
                yield HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file)


if __name__ == '__main__':
    # Script entry point: collect the parametrized Hugging Face models from the server unit tests
    # and run llama-cli once per model so that the model files get fetched before the tests run.
    import sys

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # Deduplicate via the frozen (hashable) model objects, then sort for a stable order.
    # hf_file may be None, and None cannot be compared with a str, so the key orders
    # None entries first explicitly instead of comparing (hf_repo, hf_file) tuples directly.
    models = sorted(
        {
            model
            for test_file in glob.glob('examples/server/tests/unit/test_*.py')
            for model in collect_hf_model_test_parameters(test_file)
        },
        key=lambda m: (m.hf_repo, m.hf_file is not None, m.hf_file or ''))

    logging.info(f'Found {len(models)} models in parameterized tests:')
    for m in models:
        logging.info(f'  - {m.hf_repo} / {m.hf_file}')

    # NOTE(review): the env var name refers to the server binary, but its value is executed
    # as llama-cli below — confirm this is the intended override variable.
    cli_path = os.environ.get(
        'LLAMA_SERVER_BIN_PATH',
        os.path.join(
            os.path.dirname(__file__),
            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))

    for m in models:
        # Entries containing '<' are skipped silently (presumably unexpanded placeholders).
        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
            continue
        if m.hf_file is not None and '-of-' in m.hf_file:
            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
            continue
        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
        cmd = [
            cli_path,
            '-hfr', m.hf_repo,
            *([] if m.hf_file is None else ['-hff', m.hf_file]),
            '-n', '1',
            '-p', 'Hey',
            '--no-warmup',
            '--log-disable',
            '-no-cnv']
        # -fa is added for every model except these two (presumably they do not support it; TODO confirm).
        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
            cmd.append('-fa')
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n  {" ".join(cmd)}')
            sys.exit(1)
        except OSError as e:
            # The binary could not be started at all (e.g. missing or not executable).
            logging.error(f'Failed to run {cli_path}: {e}')
            sys.exit(1)