Spaces:
Runtime error
Runtime error
"""Nice output for Black. | |
The double calls (`out` -> `_out`, `err` -> `_err`) are for patching purposes in tests.
"""
import json | |
import tempfile | |
from typing import Any, Optional | |
from click import echo, style | |
from mypy_extensions import mypyc_attr | |
def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
    """Echo `message` with `err=True`, bold unless the caller chose otherwise.

    `styles` are forwarded to `style`; a ``bold=True`` entry is added when
    the caller did not pass `bold` explicitly. A `None` message skips the
    styling step and is handed to `echo` as-is. `nl` is forwarded to `echo`.
    """
    text = message
    if text is not None:
        # Only fill in the default; an explicit bold=False must win.
        styles.setdefault("bold", True)
        text = style(text, **styles)
    echo(text, nl=nl, err=True)
def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
    """Echo `message` with `err=True`, in red unless the caller chose a color.

    `styles` are forwarded to `style`; ``fg="red"`` is added when the caller
    did not pass `fg` explicitly. A `None` message skips the styling step
    and is handed to `echo` as-is. `nl` is forwarded to `echo`.
    """
    text = message
    if text is not None:
        # Only fill in the default; an explicit fg=... must win.
        styles.setdefault("fg", "red")
        text = style(text, **styles)
    echo(text, nl=nl, err=True)
def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
    """Emit a regular message by delegating to `_out`.

    All arguments are forwarded unchanged; this wrapper adds no behavior of
    its own.
    """
    _out(message, nl=nl, **styles)
def err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
    """Emit an error message by delegating to `_err`.

    All arguments are forwarded unchanged; this wrapper adds no behavior of
    its own.
    """
    _err(message, nl=nl, **styles)
def ipynb_diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Return a unified diff string between each cell in notebooks `a` and `b`.

    `a` and `b` are JSON documents with a top-level ``"cells"`` list. Cells
    are paired by position, and only positions whose cell in `a` has
    ``cell_type == "code"`` are compared. The per-cell diffs are labelled
    ``<name>:cell_<index>`` and concatenated in cell order.
    """
    a_nb = json.loads(a)
    b_nb = json.loads(b)

    per_cell = []
    for index, cell in enumerate(a_nb["cells"]):
        if cell["cell_type"] != "code":
            continue
        # Cell sources are joined and given a trailing newline before diffing.
        before = "".join(cell["source"]) + "\n"
        after = "".join(b_nb["cells"][index]["source"]) + "\n"
        per_cell.append(
            diff(before, after, f"{a_name}:cell_{index}", f"{b_name}:cell_{index}")
        )
    return "".join(per_cell)
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Return a unified diff string between strings `a` and `b`.

    `a_name` and `b_name` label the two sides in the diff header. Each hunk
    carries up to five lines of context. The result is the empty string when
    the inputs are equal.
    """
    import difflib

    delta = difflib.unified_diff(
        a.splitlines(keepends=True),
        b.splitlines(keepends=True),
        fromfile=a_name,
        tofile=b_name,
        n=5,
    )
    pieces = []
    for entry in delta:
        if entry.endswith("\n"):
            pieces.append(entry)
            continue
        # Work around https://bugs.python.org/issue2142
        # See:
        # https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
        pieces.extend((entry + "\n", "\\ No newline at end of file\n"))
    return "".join(pieces)
def color_diff(contents: str) -> str:
    """Inject the ANSI color codes to the diff.

    Each line of `contents` is wrapped in an escape sequence chosen by its
    prefix: file headers bold, hunk headers cyan, additions green, removals
    red. Lines matching none of these are returned untouched.
    """
    reset = "\033[0m"
    # Order matters: "+++"/"---" must be tested before the bare "+"/"-".
    rules = (
        (("+++", "---"), "\033[1m"),  # bold
        (("@@",), "\033[36m"),  # cyan
        (("+",), "\033[32m"),  # green
        (("-",), "\033[31m"),  # red
    )

    painted = []
    for row in contents.split("\n"):
        for prefixes, code in rules:
            if row.startswith(prefixes):
                row = code + row + reset
                break
        painted.append(row)
    return "\n".join(painted)
def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
    """Dump `output` to a temporary file. Return path to the file.

    The file is created with prefix ``blk_`` and suffix ``.log``, is written
    as UTF-8, and is not deleted when closed. When `ensure_final_newline` is
    true, a newline is written after every non-empty chunk that does not
    already end with one.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
    ) as tmp:
        for chunk in output:
            tmp.write(chunk)
            # Empty chunks are skipped by the newline check.
            if ensure_final_newline and chunk and not chunk.endswith("\n"):
                tmp.write("\n")
    return tmp.name