File size: 7,124 Bytes
db5855f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import argparse
import re
from pathlib import Path
import nbformat
import nbconvert
from traitlets.config import Config


# Notebooks that are excluded from the CI tests
EXCLUDED_NOTEBOOKS = [
    "data-preparation-ct-scan.ipynb",
    "pytorch-monai-training.ipynb",
]

# Substring that identifies the cell defining the device-selection dropdown
DEVICE_WIDGET = "device = widgets.Dropdown("


def disable_gradio_debug(nb, notebook_path):
    """Switch Gradio debug mode off in a notebook.

    Every cell whose source mentions both "gradio" and "debug" has each
    ``debug=True`` occurrence rewritten to ``debug=False``.

    :param nb: Notebook node; its cells are modified in place.
    :param notebook_path: Notebook path, used only in the progress message.
    :return: The same ``nb`` object.
    """
    gradio_cells = [
        candidate
        for candidate in nb["cells"]
        if "gradio" in candidate["source"] and "debug" in candidate["source"]
    ]
    for candidate in gradio_cells:
        candidate["source"] = candidate["source"].replace("debug=True", "debug=False")

    if gradio_cells:
        print(f"Disabled gradio debug mode for {notebook_path}")
    return nb


def disable_skip_ext(nb, notebook_path, test_device=""):
    """Neutralize ``%%skip`` cell magics so notebooks run without the skip extension.

    By default the ``%%skip ...`` line is removed from every cell that has
    one, so the rest of the cell executes. When ``test_device`` is given and
    a cell containing both ``skip_for_device = "<TEST_DEVICE>" in device.value``
    and the ``to_quantize`` checkbox definition is found, every ``%%skip``
    cell from that point on is instead wrapped in a triple-quoted string so
    its body does not execute.

    :param nb: Notebook node; its cells are modified in place.
    :param notebook_path: Notebook path, used only in the progress message.
    :param test_device: Device under test (e.g. "GPU"). It is upper-cased
        before matching, in line with how ``patch_notebooks`` sets the device
        widget value to ``test_device.upper()``.
    :return: The same ``nb`` object.
    """
    found = False

    # None  -> not decided yet (only possible when a test device was given);
    # False -> always strip the magic line;
    # True  -> device check found, comment out the remaining skip cells.
    skip_for_device = None if test_device else False
    device_check = 'skip_for_device = "{}" in device.value'.format(test_device.upper()) if test_device else None
    quantize_checkbox = "to_quantize = widgets.Checkbox(value=not skip_for_device"

    for cell in nb["cells"]:
        if skip_for_device is None and device_check in cell["source"] and quantize_checkbox in cell["source"]:
            skip_for_device = True

        if "%%skip" in cell["source"]:
            found = True
            if skip_for_device:
                cell["source"] = '"""\n' + cell["source"] + '\n"""'
            else:
                # Drop the magic line, also when it has no arguments or is
                # the last line of the cell (no trailing newline).
                cell["source"] = re.sub(r"%%skip.*\n?", "\n", cell["source"])
    if found:
        print(f"Disabled skip extension mode for {notebook_path}")
    return nb


def remove_ov_install(cell):
    """Strip OpenVINO packages from the install lines of a notebook cell.

    Every line of ``cell["source"]`` that mentions "openvino" is split on
    single spaces and the OpenVINO package tokens are removed: ``openvino``,
    tokens containing ``openvino>`` or ``openvino=``, and tokens containing
    ``openvino-dev``, ``openvino-nightly`` or ``openvino-tokenizers``. Then:

    * if other packages remain, the rewritten line is kept (with its original
      indentation), followed by the original line commented out;
    * if only pip syntax remains (``%pip``, ``install``, options, https URLs),
      the line is dropped, or replaced by ``pass`` at the same indentation
      when it is indented, so an enclosing block never ends up empty;
    * lines without a recognized package token are left unchanged.

    ``cell`` is modified in place; nothing is returned.
    """

    def is_ov_package(part):
        # Tokens naming an OpenVINO distribution that should not be installed here.
        if "openvino-dev" in part or "openvino-nightly" in part or "openvino-tokenizers" in part:
            return True
        return "openvino>" in part or "openvino=" in part or "openvino" == part

    def has_additional_deps(str_part):
        # Anything that is not pip syntax (magic, sub-command, option, URL)
        # counts as another package that still has to be installed.
        if "%pip" in str_part:
            return False
        if "install" in str_part:
            return False
        if str_part.startswith("-"):
            return False
        if str_part.startswith("https://"):
            return False
        return True

    updated_lines = []
    for line in cell["source"].split("\n"):
        if "openvino" not in line:
            updated_lines.append(line)
            continue

        updated_line_content = []
        empty = True
        package_found = False
        for part in line.split(" "):
            if is_ov_package(part):
                package_found = True
                continue
            # Empty parts come from indentation or repeated spaces. They are kept
            # so that " ".join() restores the original spacing, but they are
            # not packages and must not mark the line as non-empty.
            if empty and part and has_additional_deps(part):
                empty = False
            updated_line_content.append(part)

        if not package_found:
            updated_lines.append(line)
        elif not empty:
            # " ".join() already reproduces the leading indentation.
            updated_lines.append(" ".join(updated_line_content) + "\n# " + line)
        else:
            indent = line[: len(line) - len(line.lstrip())]
            if indent:
                # Keep a statement in place so an enclosing block stays valid.
                updated_lines.append(indent + "pass")
    cell["source"] = "\n".join(updated_lines)


def patch_notebooks(notebooks_dir, test_device="", skip_ov_install=False):
    """
    Write test copies of the notebooks, patched to speed up CI execution.

    Every ``*.ipynb`` file under ``notebooks_dir`` (searched recursively),
    except files whose name starts with ``test_`` and those listed in
    ``EXCLUDED_NOTEBOOKS``, is read, patched, cleared of outputs and written
    next to the original as ``test_<name>.ipynb``. The original notebook file
    is not modified.

    This function is specific for the OpenVINO notebooks GitHub Actions CI.

    Replacement values come from the ``test_replace`` entry in cell metadata.
    For example, to change the number of epochs from 15 to 1 in
    tensorflow-training-openvino-nncf.ipynb, add
    ``{"test_replace": {"epochs = 15": "epochs = 1"}}`` to the metadata of the
    cell that contains ``epochs = 15``.

    :param notebooks_dir: Directory that contains the notebook subdirectories.
                          For example: openvino_notebooks/notebooks
    :param test_device: Device to test on (e.g. "GPU"). When set, the device
                        dropdown cell is patched to select it. Empty string
                        keeps the notebooks' own device selection.
    :param skip_ov_install: Remove OpenVINO packages from ``%pip`` cells.
    :raises ValueError: If a ``test_replace`` source value does not occur in
                        its cell.
    """

    nb_convert_config = Config()
    nb_convert_config.NotebookExporter.preprocessors = ["nbconvert.preprocessors.ClearOutputPreprocessor"]
    output_remover = nbconvert.NotebookExporter(nb_convert_config)
    for notebookfile in Path(notebooks_dir).glob("**/*.ipynb"):
        # Skip previously generated test copies and explicitly excluded notebooks.
        if notebookfile.name.startswith("test_") or notebookfile.name in EXCLUDED_NOTEBOOKS:
            continue

        nb = nbformat.read(notebookfile, as_version=nbformat.NO_CONVERT)
        found = False
        device_found = False
        for cell in nb["cells"]:
            if skip_ov_install and "%pip" in cell["source"]:
                remove_ov_install(cell)
            if test_device and DEVICE_WIDGET in cell["source"]:
                device_found = True
                # Select the test device and make sure it is among the options.
                cell["source"] = re.sub(r"value=.*,", f"value='{test_device.upper()}',", cell["source"])
                cell["source"] = re.sub(
                    r"options=",
                    f"options=['{test_device.upper()}'] + ",
                    cell["source"],
                )
                print(f"Replaced testing device to {test_device}")
            replace_dict = cell.get("metadata", {}).get("test_replace")
            if replace_dict is not None:
                found = True
                for source_value, target_value in replace_dict.items():
                    if source_value not in cell["source"]:
                        raise ValueError(f"Processing {notebookfile} failed: {source_value} does not exist in cell")
                    cell["source"] = cell["source"].replace(source_value, target_value)
                    print(f"Processed {notebookfile}: {source_value} -> {target_value}")
                if replace_dict:
                    # Mark the cell once, however many replacements were applied.
                    cell["source"] = "# Modified for testing\n" + cell["source"]
        if test_device and not device_found:
            print(f"No device replacement found for {notebookfile}")
        if not found:
            print(f"No replacements found for {notebookfile}")
        disable_gradio_debug(nb, notebookfile)
        # Use the parameter, not a module-level `args`, so the function also
        # works when imported and called from other code.
        disable_skip_ext(nb, notebookfile, test_device)
        nb_without_out, _ = output_remover.from_notebook_node(nb)
        with notebookfile.with_name(f"test_{notebookfile.name}").open("w", encoding="utf-8") as out_file:
            out_file.write(nb_without_out)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Notebook patcher")
    # nargs="?" makes the positional optional so its default "." actually
    # applies; without it argparse requires the argument and ignores the default.
    parser.add_argument(
        "notebooks_dir",
        nargs="?",
        default=".",
        help="directory containing the notebooks to patch (default: current directory)",
    )
    parser.add_argument("-td", "--test_device", default="", help="device to run the notebooks on, e.g. GPU")
    parser.add_argument("--skip_ov_install", action="store_true", help="remove OpenVINO packages from %%pip cells")
    args = parser.parse_args()
    if not Path(args.notebooks_dir).is_dir():
        raise ValueError(f"'{args.notebooks_dir}' is not an existing directory")
    patch_notebooks(args.notebooks_dir, args.test_device, args.skip_ov_install)