Spaces:
Runtime error
Runtime error
File size: 7,124 Bytes
db5855f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import argparse
import re
from pathlib import Path
import nbformat
import nbconvert
from traitlets.config import Config
# Notebooks that are excluded from the CI tests (compared against the file name only).
EXCLUDED_NOTEBOOKS = [
    "data-preparation-ct-scan.ipynb",
    "pytorch-monai-training.ipynb",
]
# Substring that identifies the cell defining the device selection widget.
DEVICE_WIDGET = "device = widgets.Dropdown("
def disable_gradio_debug(nb, notebook_path):
    """
    Switch Gradio debug mode off in every cell mentioning both "gradio" and "debug".

    Occurrences of ``debug=True`` in those cells are rewritten to ``debug=False``.

    :param nb: Notebook node whose cells are edited in place.
    :param notebook_path: Path used only in the log message.
    :return: The same ``nb`` object.
    """
    touched = False
    for cell in nb["cells"]:
        src = cell["source"]
        if "gradio" not in src or "debug" not in src:
            continue
        touched = True
        cell["source"] = src.replace("debug=True", "debug=False")
    if touched:
        print(f"Disabled gradio debug mode for {notebook_path}")
    return nb
def disable_skip_ext(nb, notebook_path, test_device=""):
    """
    Neutralize ``%%skip`` cell magics in a notebook, in place.

    When ``test_device`` is truthy, cells are scanned in order for a cell that
    contains both ``skip_for_device = "<test_device>" in device.value`` and the
    ``to_quantize = widgets.Checkbox(value=not skip_for_device`` definition.

    - Cells with ``%%skip`` seen before that marker cell (or when
      ``test_device`` is falsy) have the ``%%skip`` line removed, so their
      body is kept as ordinary code.
    - Cells with ``%%skip`` in or after the marker cell are wrapped in a
      triple-quoted string, so their body is not executed.

    :param nb: Notebook node whose cells are edited in place.
    :param notebook_path: Path used only in the log message.
    :param test_device: Device name being tested; falsy disables marker detection.
    :return: The same ``nb`` object.
    """
    # None means "still undecided": the device marker has not been seen yet.
    skip_for_device = None if test_device else False
    marker = 'skip_for_device = "{}" in device.value'.format(test_device)
    checkbox = "to_quantize = widgets.Checkbox(value=not skip_for_device"
    patched_any = False
    for cell in nb["cells"]:
        src = cell["source"]
        if skip_for_device is None and test_device is not None:
            if marker in src and checkbox in src:
                skip_for_device = True
        if "%%skip" not in src:
            continue
        patched_any = True
        if skip_for_device:
            cell["source"] = '"""\n' + src + '\n"""'
        else:
            cell["source"] = re.sub(r"%%skip.*.\n", "\n", src)
    if patched_any:
        print(f"Disabled skip extension mode for {notebook_path}")
    return nb
def remove_ov_install(cell):
    """
    Strip OpenVINO packages from install lines of a notebook cell, in place.

    A line mentioning "openvino" is split on spaces. Tokens naming an OpenVINO
    package (``openvino``, ``openvino-dev``, ``openvino-nightly``,
    ``openvino-tokenizers``, or specs containing ``openvino>`` / ``openvino=``)
    are removed. Then:

    * if other packages remain, the rewritten line (with its original
      indentation applied exactly once) is followed by the original line as
      a ``# `` comment;
    * if nothing else remains and the line is indented (e.g. inside an
      ``if``), it becomes ``pass`` plus the commented original, so the
      enclosing block is not left empty;
    * if nothing else remains on a top-level line, the line is dropped.

    Python ``import`` / ``from`` statements are never rewritten, so
    ``import openvino as ov`` in the same cell survives intact.

    :param cell: Notebook cell whose ``"source"`` is a single string; it is
        modified in place. Returns None.
    """

    def has_additional_deps(str_part):
        # True when a token names something to install, i.e. it is not the magic,
        # the install verb, an option flag, a URL, or an empty token produced by
        # consecutive/leading spaces.
        if not str_part:
            return False
        if "%pip" in str_part:
            return False
        if "install" in str_part:
            return False
        if str_part.startswith("-"):
            return False
        if str_part.startswith("https://"):
            return False
        return True

    def is_ov_package(str_part):
        # True when a token is one of the OpenVINO packages to be removed.
        return (
            "openvino-dev" in str_part
            or "openvino-nightly" in str_part
            or "openvino-tokenizers" in str_part
            or "openvino>" in str_part
            or "openvino=" in str_part
            or str_part == "openvino"
        )

    updated_lines = []
    for line in cell["source"].split("\n"):
        stripped = line.lstrip(" ")
        # Import statements are not install commands; rewriting them would turn
        # "import openvino as ov" into "import as ov".
        if "openvino" not in line or stripped.startswith(("import ", "from ")):
            updated_lines.append(line)
            continue
        indent = line[: len(line) - len(stripped)]
        kept_parts = []
        package_found = False
        has_extra = False
        # Split the de-indented text so leading spaces do not become empty tokens.
        for part in stripped.split(" "):
            if is_ov_package(part):
                package_found = True
                continue
            if has_additional_deps(part):
                has_extra = True
            kept_parts.append(part)
        if not package_found:
            updated_lines.append(line)
        elif has_extra:
            updated_lines.append(indent + " ".join(kept_parts) + "\n# " + line)
        elif indent:
            updated_lines.append(indent + "pass\n# " + line)
        # else: a top-level line that only installed OpenVINO is dropped.
    cell["source"] = "\n".join(updated_lines)
def patch_notebooks(notebooks_dir, test_device="", skip_ov_install=False):
    """
    Patch notebooks in notebooks directory with replacement values
    found in notebook metadata to speed up test execution.
    This function is specific for the OpenVINO notebooks
    Github Actions CI.

    For example: change nr of epochs from 15 to 1 in
    tensorflow-training-openvino-nncf.ipynb by adding
    {"test_replace": {"epochs = 15": "epochs = 1"}} to the cell
    metadata of the cell that contains `epochs = 15`

    Each processed notebook is written, with outputs cleared, next to the
    original as ``test_<name>.ipynb``. Notebooks whose name already starts
    with ``test_`` or is listed in EXCLUDED_NOTEBOOKS are skipped.

    :param notebooks_dir: Directory that contains the notebook subdirectories.
        For example: openvino_notebooks/notebooks
    :param test_device: If non-empty, cells containing DEVICE_WIDGET are
        rewritten to default to, and offer, ``test_device.upper()``. It is also
        forwarded to disable_skip_ext.
    :param skip_ov_install: If True, OpenVINO packages are stripped from cells
        containing ``%pip`` via remove_ov_install.
    :raises ValueError: If a ``test_replace`` key does not occur in its cell.
    """
    nb_convert_config = Config()
    nb_convert_config.NotebookExporter.preprocessors = ["nbconvert.preprocessors.ClearOutputPreprocessor"]
    output_remover = nbconvert.NotebookExporter(nb_convert_config)
    # Snapshot the file list up front: test_*.ipynb files are created in the same
    # tree while iterating, and sorting makes the processing order deterministic.
    for notebookfile in sorted(Path(notebooks_dir).glob("**/*.ipynb")):
        if notebookfile.name.startswith("test_") or notebookfile.name in EXCLUDED_NOTEBOOKS:
            continue
        nb = nbformat.read(notebookfile, as_version=nbformat.NO_CONVERT)
        found = False
        device_found = False
        for cell in nb["cells"]:
            if skip_ov_install and "%pip" in cell["source"]:
                remove_ov_install(cell)
            if test_device and DEVICE_WIDGET in cell["source"]:
                device_found = True
                # NOTE(review): the greedy ".*" runs to the last comma on a line and
                # applies to every "value=" in the cell; this assumes the widget's
                # arguments are written one per line.
                cell["source"] = re.sub(r"value=.*,", f"value='{test_device.upper()}',", cell["source"])
                cell["source"] = re.sub(
                    r"options=",
                    f"options=['{test_device.upper()}'] + ",
                    cell["source"],
                )
                print(f"Replaced testing device to {test_device}")
            replace_dict = cell.get("metadata", {}).get("test_replace")
            if replace_dict is not None:
                found = True
                for source_value, target_value in replace_dict.items():
                    if source_value not in cell["source"]:
                        raise ValueError(f"Processing {notebookfile} failed: {source_value} does not exist in cell")
                    cell["source"] = cell["source"].replace(source_value, target_value)
                    cell["source"] = "# Modified for testing\n" + cell["source"]
                    print(f"Processed {notebookfile}: {source_value} -> {target_value}")
        if test_device and not device_found:
            print(f"No device replacement found for {notebookfile}")
        if not found:
            print(f"No replacements found for {notebookfile}")
        disable_gradio_debug(nb, notebookfile)
        # Use this function's own parameter rather than a CLI-level global, so
        # library callers work and the test_device they pass is honored.
        disable_skip_ext(nb, notebookfile, test_device)
        nb_without_out, _ = output_remover.from_notebook_node(nb)
        with notebookfile.with_name(f"test_{notebookfile.name}").open("w", encoding="utf-8") as out_file:
            out_file.write(nb_without_out)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Notebook patcher")
    # nargs="?" makes the positional optional; without it default="." never applies.
    parser.add_argument("notebooks_dir", nargs="?", default=".")
    parser.add_argument("-td", "--test_device", default="")
    parser.add_argument("--skip_ov_install", action="store_true")
    # Kept as the module-level name `args`, in case other code in the file reads it.
    args = parser.parse_args()
    if not Path(args.notebooks_dir).is_dir():
        raise ValueError(f"'{args.notebooks_dir}' is not an existing directory")
    patch_notebooks(args.notebooks_dir, args.test_device, args.skip_ov_install)
|