Spaces:
Runtime error
Runtime error
import io | |
import json | |
import platform | |
import re | |
import sys | |
import tokenize | |
import traceback | |
from contextlib import contextmanager | |
from dataclasses import replace | |
from datetime import datetime, timezone | |
from enum import Enum | |
from json.decoder import JSONDecodeError | |
from pathlib import Path | |
from typing import ( | |
Any, | |
Dict, | |
Generator, | |
Iterator, | |
List, | |
MutableMapping, | |
Optional, | |
Pattern, | |
Sequence, | |
Set, | |
Sized, | |
Tuple, | |
Union, | |
) | |
import click | |
from click.core import ParameterSource | |
from mypy_extensions import mypyc_attr | |
from pathspec import PathSpec | |
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError | |
from _black_version import version as __version__ | |
from black.cache import Cache, get_cache_info, read_cache, write_cache | |
from black.comments import normalize_fmt_off | |
from black.const import ( | |
DEFAULT_EXCLUDES, | |
DEFAULT_INCLUDES, | |
DEFAULT_LINE_LENGTH, | |
STDIN_PLACEHOLDER, | |
) | |
from black.files import ( | |
find_project_root, | |
find_pyproject_toml, | |
find_user_pyproject_toml, | |
gen_python_files, | |
get_gitignore, | |
normalize_path_maybe_ignore, | |
parse_pyproject_toml, | |
wrap_stream_for_windows, | |
) | |
from black.handle_ipynb_magics import ( | |
PYTHON_CELL_MAGICS, | |
TRANSFORMED_MAGICS, | |
jupyter_dependencies_are_installed, | |
mask_cell, | |
put_trailing_semicolon_back, | |
remove_trailing_semicolon, | |
unmask_cell, | |
) | |
from black.linegen import LN, LineGenerator, transform_line | |
from black.lines import EmptyLineTracker, LinesBlock | |
from black.mode import ( | |
FUTURE_FLAG_TO_FEATURE, | |
VERSION_TO_FEATURES, | |
Feature, | |
Mode, | |
TargetVersion, | |
supports_feature, | |
) | |
from black.nodes import ( | |
STARS, | |
is_number_token, | |
is_simple_decorator_expression, | |
is_string_token, | |
syms, | |
) | |
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out | |
from black.parsing import InvalidInput # noqa F401 | |
from black.parsing import lib2to3_parse, parse_ast, stringify_ast | |
from black.report import Changed, NothingChanged, Report | |
from black.trans import iter_fexpr_spans | |
from blib2to3.pgen2 import token | |
from blib2to3.pytree import Leaf, Node | |
COMPILED = Path(__file__).suffix in (".pyd", ".so") | |
# types | |
FileContent = str | |
Encoding = str | |
NewLine = str | |
class WriteBack(Enum): | |
NO = 0 | |
YES = 1 | |
DIFF = 2 | |
CHECK = 3 | |
COLOR_DIFF = 4 | |
def from_configuration( | |
cls, *, check: bool, diff: bool, color: bool = False | |
) -> "WriteBack": | |
if check and not diff: | |
return cls.CHECK | |
if diff and color: | |
return cls.COLOR_DIFF | |
return cls.DIFF if diff else cls.YES | |
# Legacy name, left for integrations. | |
FileMode = Mode | |
def read_pyproject_toml( | |
ctx: click.Context, param: click.Parameter, value: Optional[str] | |
) -> Optional[str]: | |
"""Inject Black configuration from "pyproject.toml" into defaults in `ctx`. | |
Returns the path to a successfully found and read configuration file, None | |
otherwise. | |
""" | |
if not value: | |
value = find_pyproject_toml( | |
ctx.params.get("src", ()), ctx.params.get("stdin_filename", None) | |
) | |
if value is None: | |
return None | |
try: | |
config = parse_pyproject_toml(value) | |
except (OSError, ValueError) as e: | |
raise click.FileError( | |
filename=value, hint=f"Error reading configuration file: {e}" | |
) from None | |
if not config: | |
return None | |
else: | |
# Sanitize the values to be Click friendly. For more information please see: | |
# https://github.com/psf/black/issues/1458 | |
# https://github.com/pallets/click/issues/1567 | |
config = { | |
k: str(v) if not isinstance(v, (list, dict)) else v | |
for k, v in config.items() | |
} | |
target_version = config.get("target_version") | |
if target_version is not None and not isinstance(target_version, list): | |
raise click.BadOptionUsage( | |
"target-version", "Config key target-version must be a list" | |
) | |
exclude = config.get("exclude") | |
if exclude is not None and not isinstance(exclude, str): | |
raise click.BadOptionUsage("exclude", "Config key exclude must be a string") | |
extend_exclude = config.get("extend_exclude") | |
if extend_exclude is not None and not isinstance(extend_exclude, str): | |
raise click.BadOptionUsage( | |
"extend-exclude", "Config key extend-exclude must be a string" | |
) | |
default_map: Dict[str, Any] = {} | |
if ctx.default_map: | |
default_map.update(ctx.default_map) | |
default_map.update(config) | |
ctx.default_map = default_map | |
return value | |
def target_version_option_callback( | |
c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] | |
) -> List[TargetVersion]: | |
"""Compute the target versions from a --target-version flag. | |
This is its own function because mypy couldn't infer the type correctly | |
when it was a lambda, causing mypyc trouble. | |
""" | |
return [TargetVersion[val.upper()] for val in v] | |
def re_compile_maybe_verbose(regex: str) -> Pattern[str]: | |
"""Compile a regular expression string in `regex`. | |
If it contains newlines, use verbose mode. | |
""" | |
if "\n" in regex: | |
regex = "(?x)" + regex | |
compiled: Pattern[str] = re.compile(regex) | |
return compiled | |
def validate_regex( | |
ctx: click.Context, | |
param: click.Parameter, | |
value: Optional[str], | |
) -> Optional[Pattern[str]]: | |
try: | |
return re_compile_maybe_verbose(value) if value is not None else None | |
except re.error as e: | |
raise click.BadParameter(f"Not a valid regular expression: {e}") from None | |
def main( # noqa: C901 | |
ctx: click.Context, | |
code: Optional[str], | |
line_length: int, | |
target_version: List[TargetVersion], | |
check: bool, | |
diff: bool, | |
color: bool, | |
fast: bool, | |
pyi: bool, | |
ipynb: bool, | |
python_cell_magics: Sequence[str], | |
skip_source_first_line: bool, | |
skip_string_normalization: bool, | |
skip_magic_trailing_comma: bool, | |
experimental_string_processing: bool, | |
preview: bool, | |
quiet: bool, | |
verbose: bool, | |
required_version: Optional[str], | |
include: Pattern[str], | |
exclude: Optional[Pattern[str]], | |
extend_exclude: Optional[Pattern[str]], | |
force_exclude: Optional[Pattern[str]], | |
stdin_filename: Optional[str], | |
workers: Optional[int], | |
src: Tuple[str, ...], | |
config: Optional[str], | |
) -> None: | |
"""The uncompromising code formatter.""" | |
ctx.ensure_object(dict) | |
if src and code is not None: | |
out( | |
main.get_usage(ctx) | |
+ "\n\n'SRC' and 'code' cannot be passed simultaneously." | |
) | |
ctx.exit(1) | |
if not src and code is None: | |
out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.") | |
ctx.exit(1) | |
root, method = ( | |
find_project_root(src, stdin_filename) if code is None else (None, None) | |
) | |
ctx.obj["root"] = root | |
if verbose: | |
if root: | |
out( | |
f"Identified `{root}` as project root containing a {method}.", | |
fg="blue", | |
) | |
if config: | |
config_source = ctx.get_parameter_source("config") | |
user_level_config = str(find_user_pyproject_toml()) | |
if config == user_level_config: | |
out( | |
"Using configuration from user-level config at " | |
f"'{user_level_config}'.", | |
fg="blue", | |
) | |
elif config_source in ( | |
ParameterSource.DEFAULT, | |
ParameterSource.DEFAULT_MAP, | |
): | |
out("Using configuration from project root.", fg="blue") | |
else: | |
out(f"Using configuration in '{config}'.", fg="blue") | |
if ctx.default_map: | |
for param, value in ctx.default_map.items(): | |
out(f"{param}: {value}") | |
error_msg = "Oh no! 💥 💔 💥" | |
if ( | |
required_version | |
and required_version != __version__ | |
and required_version != __version__.split(".")[0] | |
): | |
err( | |
f"{error_msg} The required version `{required_version}` does not match" | |
f" the running version `{__version__}`!" | |
) | |
ctx.exit(1) | |
if ipynb and pyi: | |
err("Cannot pass both `pyi` and `ipynb` flags!") | |
ctx.exit(1) | |
write_back = WriteBack.from_configuration(check=check, diff=diff, color=color) | |
if target_version: | |
versions = set(target_version) | |
else: | |
# We'll autodetect later. | |
versions = set() | |
mode = Mode( | |
target_versions=versions, | |
line_length=line_length, | |
is_pyi=pyi, | |
is_ipynb=ipynb, | |
skip_source_first_line=skip_source_first_line, | |
string_normalization=not skip_string_normalization, | |
magic_trailing_comma=not skip_magic_trailing_comma, | |
experimental_string_processing=experimental_string_processing, | |
preview=preview, | |
python_cell_magics=set(python_cell_magics), | |
) | |
if code is not None: | |
# Run in quiet mode by default with -c; the extra output isn't useful. | |
# You can still pass -v to get verbose output. | |
quiet = True | |
report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) | |
if code is not None: | |
reformat_code( | |
content=code, fast=fast, write_back=write_back, mode=mode, report=report | |
) | |
else: | |
try: | |
sources = get_sources( | |
ctx=ctx, | |
src=src, | |
quiet=quiet, | |
verbose=verbose, | |
include=include, | |
exclude=exclude, | |
extend_exclude=extend_exclude, | |
force_exclude=force_exclude, | |
report=report, | |
stdin_filename=stdin_filename, | |
) | |
except GitWildMatchPatternError: | |
ctx.exit(1) | |
path_empty( | |
sources, | |
"No Python files are present to be formatted. Nothing to do 😴", | |
quiet, | |
verbose, | |
ctx, | |
) | |
if len(sources) == 1: | |
reformat_one( | |
src=sources.pop(), | |
fast=fast, | |
write_back=write_back, | |
mode=mode, | |
report=report, | |
) | |
else: | |
from black.concurrency import reformat_many | |
reformat_many( | |
sources=sources, | |
fast=fast, | |
write_back=write_back, | |
mode=mode, | |
report=report, | |
workers=workers, | |
) | |
if verbose or not quiet: | |
if code is None and (verbose or report.change_count or report.failure_count): | |
out() | |
out(error_msg if report.return_code else "All done! ✨ 🍰 ✨") | |
if code is None: | |
click.echo(str(report), err=True) | |
ctx.exit(report.return_code) | |
def get_sources( | |
*, | |
ctx: click.Context, | |
src: Tuple[str, ...], | |
quiet: bool, | |
verbose: bool, | |
include: Pattern[str], | |
exclude: Optional[Pattern[str]], | |
extend_exclude: Optional[Pattern[str]], | |
force_exclude: Optional[Pattern[str]], | |
report: "Report", | |
stdin_filename: Optional[str], | |
) -> Set[Path]: | |
"""Compute the set of files to be formatted.""" | |
sources: Set[Path] = set() | |
root = ctx.obj["root"] | |
using_default_exclude = exclude is None | |
exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude | |
gitignore: Optional[Dict[Path, PathSpec]] = None | |
root_gitignore = get_gitignore(root) | |
for s in src: | |
if s == "-" and stdin_filename: | |
p = Path(stdin_filename) | |
is_stdin = True | |
else: | |
p = Path(s) | |
is_stdin = False | |
if is_stdin or p.is_file(): | |
normalized_path: Optional[str] = normalize_path_maybe_ignore( | |
p, ctx.obj["root"], report | |
) | |
if normalized_path is None: | |
if verbose: | |
out(f'Skipping invalid source: "{normalized_path}"', fg="red") | |
continue | |
if verbose: | |
out(f'Found input source: "{normalized_path}"', fg="blue") | |
normalized_path = "/" + normalized_path | |
# Hard-exclude any files that matches the `--force-exclude` regex. | |
if force_exclude: | |
force_exclude_match = force_exclude.search(normalized_path) | |
else: | |
force_exclude_match = None | |
if force_exclude_match and force_exclude_match.group(0): | |
report.path_ignored(p, "matches the --force-exclude regular expression") | |
continue | |
if is_stdin: | |
p = Path(f"{STDIN_PLACEHOLDER}{str(p)}") | |
if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed( | |
verbose=verbose, quiet=quiet | |
): | |
continue | |
sources.add(p) | |
elif p.is_dir(): | |
p = root / normalize_path_maybe_ignore(p, ctx.obj["root"], report) | |
if verbose: | |
out(f'Found input source directory: "{p}"', fg="blue") | |
if using_default_exclude: | |
gitignore = { | |
root: root_gitignore, | |
p: get_gitignore(p), | |
} | |
sources.update( | |
gen_python_files( | |
p.iterdir(), | |
ctx.obj["root"], | |
include, | |
exclude, | |
extend_exclude, | |
force_exclude, | |
report, | |
gitignore, | |
verbose=verbose, | |
quiet=quiet, | |
) | |
) | |
elif s == "-": | |
if verbose: | |
out("Found input source stdin", fg="blue") | |
sources.add(p) | |
else: | |
err(f"invalid path: {s}") | |
return sources | |
def path_empty( | |
src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context | |
) -> None: | |
""" | |
Exit if there is no `src` provided for formatting | |
""" | |
if not src: | |
if verbose or not quiet: | |
out(msg) | |
ctx.exit(0) | |
def reformat_code( | |
content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report | |
) -> None: | |
""" | |
Reformat and print out `content` without spawning child processes. | |
Similar to `reformat_one`, but for string content. | |
`fast`, `write_back`, and `mode` options are passed to | |
:func:`format_file_in_place` or :func:`format_stdin_to_stdout`. | |
""" | |
path = Path("<string>") | |
try: | |
changed = Changed.NO | |
if format_stdin_to_stdout( | |
content=content, fast=fast, write_back=write_back, mode=mode | |
): | |
changed = Changed.YES | |
report.done(path, changed) | |
except Exception as exc: | |
if report.verbose: | |
traceback.print_exc() | |
report.failed(path, str(exc)) | |
# diff-shades depends on being to monkeypatch this function to operate. I know it's | |
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 | |
def reformat_one( | |
src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" | |
) -> None: | |
"""Reformat a single file under `src` without spawning child processes. | |
`fast`, `write_back`, and `mode` options are passed to | |
:func:`format_file_in_place` or :func:`format_stdin_to_stdout`. | |
""" | |
try: | |
changed = Changed.NO | |
if str(src) == "-": | |
is_stdin = True | |
elif str(src).startswith(STDIN_PLACEHOLDER): | |
is_stdin = True | |
# Use the original name again in case we want to print something | |
# to the user | |
src = Path(str(src)[len(STDIN_PLACEHOLDER) :]) | |
else: | |
is_stdin = False | |
if is_stdin: | |
if src.suffix == ".pyi": | |
mode = replace(mode, is_pyi=True) | |
elif src.suffix == ".ipynb": | |
mode = replace(mode, is_ipynb=True) | |
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): | |
changed = Changed.YES | |
else: | |
cache: Cache = {} | |
if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): | |
cache = read_cache(mode) | |
res_src = src.resolve() | |
res_src_s = str(res_src) | |
if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src): | |
changed = Changed.CACHED | |
if changed is not Changed.CACHED and format_file_in_place( | |
src, fast=fast, write_back=write_back, mode=mode | |
): | |
changed = Changed.YES | |
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or ( | |
write_back is WriteBack.CHECK and changed is Changed.NO | |
): | |
write_cache(cache, [src], mode) | |
report.done(src, changed) | |
except Exception as exc: | |
if report.verbose: | |
traceback.print_exc() | |
report.failed(src, str(exc)) | |
def format_file_in_place( | |
src: Path, | |
fast: bool, | |
mode: Mode, | |
write_back: WriteBack = WriteBack.NO, | |
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy | |
) -> bool: | |
"""Format file under `src` path. Return True if changed. | |
If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted | |
code to the file. | |
`mode` and `fast` options are passed to :func:`format_file_contents`. | |
""" | |
if src.suffix == ".pyi": | |
mode = replace(mode, is_pyi=True) | |
elif src.suffix == ".ipynb": | |
mode = replace(mode, is_ipynb=True) | |
then = datetime.fromtimestamp(src.stat().st_mtime, timezone.utc) | |
header = b"" | |
with open(src, "rb") as buf: | |
if mode.skip_source_first_line: | |
header = buf.readline() | |
src_contents, encoding, newline = decode_bytes(buf.read()) | |
try: | |
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode) | |
except NothingChanged: | |
return False | |
except JSONDecodeError: | |
raise ValueError( | |
f"File '{src}' cannot be parsed as valid Jupyter notebook." | |
) from None | |
src_contents = header.decode(encoding) + src_contents | |
dst_contents = header.decode(encoding) + dst_contents | |
if write_back == WriteBack.YES: | |
with open(src, "w", encoding=encoding, newline=newline) as f: | |
f.write(dst_contents) | |
elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): | |
now = datetime.now(timezone.utc) | |
src_name = f"{src}\t{then}" | |
dst_name = f"{src}\t{now}" | |
if mode.is_ipynb: | |
diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name) | |
else: | |
diff_contents = diff(src_contents, dst_contents, src_name, dst_name) | |
if write_back == WriteBack.COLOR_DIFF: | |
diff_contents = color_diff(diff_contents) | |
with lock or nullcontext(): | |
f = io.TextIOWrapper( | |
sys.stdout.buffer, | |
encoding=encoding, | |
newline=newline, | |
write_through=True, | |
) | |
f = wrap_stream_for_windows(f) | |
f.write(diff_contents) | |
f.detach() | |
return True | |
def format_stdin_to_stdout( | |
fast: bool, | |
*, | |
content: Optional[str] = None, | |
write_back: WriteBack = WriteBack.NO, | |
mode: Mode, | |
) -> bool: | |
"""Format file on stdin. Return True if changed. | |
If content is None, it's read from sys.stdin. | |
If `write_back` is YES, write reformatted code back to stdout. If it is DIFF, | |
write a diff to stdout. The `mode` argument is passed to | |
:func:`format_file_contents`. | |
""" | |
then = datetime.now(timezone.utc) | |
if content is None: | |
src, encoding, newline = decode_bytes(sys.stdin.buffer.read()) | |
else: | |
src, encoding, newline = content, "utf-8", "" | |
dst = src | |
try: | |
dst = format_file_contents(src, fast=fast, mode=mode) | |
return True | |
except NothingChanged: | |
return False | |
finally: | |
f = io.TextIOWrapper( | |
sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True | |
) | |
if write_back == WriteBack.YES: | |
# Make sure there's a newline after the content | |
if dst and dst[-1] != "\n": | |
dst += "\n" | |
f.write(dst) | |
elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): | |
now = datetime.now(timezone.utc) | |
src_name = f"STDIN\t{then}" | |
dst_name = f"STDOUT\t{now}" | |
d = diff(src, dst, src_name, dst_name) | |
if write_back == WriteBack.COLOR_DIFF: | |
d = color_diff(d) | |
f = wrap_stream_for_windows(f) | |
f.write(d) | |
f.detach() | |
def check_stability_and_equivalence( | |
src_contents: str, dst_contents: str, *, mode: Mode | |
) -> None: | |
"""Perform stability and equivalence checks. | |
Raise AssertionError if source and destination contents are not | |
equivalent, or if a second pass of the formatter would format the | |
content differently. | |
""" | |
assert_equivalent(src_contents, dst_contents) | |
assert_stable(src_contents, dst_contents, mode=mode) | |
def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: | |
"""Reformat contents of a file and return new contents. | |
If `fast` is False, additionally confirm that the reformatted code is | |
valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. | |
`mode` is passed to :func:`format_str`. | |
""" | |
if mode.is_ipynb: | |
dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode) | |
else: | |
dst_contents = format_str(src_contents, mode=mode) | |
if src_contents == dst_contents: | |
raise NothingChanged | |
if not fast and not mode.is_ipynb: | |
# Jupyter notebooks will already have been checked above. | |
check_stability_and_equivalence(src_contents, dst_contents, mode=mode) | |
return dst_contents | |
def validate_cell(src: str, mode: Mode) -> None: | |
"""Check that cell does not already contain TransformerManager transformations, | |
or non-Python cell magics, which might cause tokenizer_rt to break because of | |
indentations. | |
If a cell contains ``!ls``, then it'll be transformed to | |
``get_ipython().system('ls')``. However, if the cell originally contained | |
``get_ipython().system('ls')``, then it would get transformed in the same way: | |
>>> TransformerManager().transform_cell("get_ipython().system('ls')") | |
"get_ipython().system('ls')\n" | |
>>> TransformerManager().transform_cell("!ls") | |
"get_ipython().system('ls')\n" | |
Due to the impossibility of safely roundtripping in such situations, cells | |
containing transformed magics will be ignored. | |
""" | |
if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): | |
raise NothingChanged | |
if ( | |
src[:2] == "%%" | |
and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics | |
): | |
raise NothingChanged | |
def format_cell(src: str, *, fast: bool, mode: Mode) -> str: | |
"""Format code in given cell of Jupyter notebook. | |
General idea is: | |
- if cell has trailing semicolon, remove it; | |
- if cell has IPython magics, mask them; | |
- format cell; | |
- reinstate IPython magics; | |
- reinstate trailing semicolon (if originally present); | |
- strip trailing newlines. | |
Cells with syntax errors will not be processed, as they | |
could potentially be automagics or multi-line magics, which | |
are currently not supported. | |
""" | |
validate_cell(src, mode) | |
src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon( | |
src | |
) | |
try: | |
masked_src, replacements = mask_cell(src_without_trailing_semicolon) | |
except SyntaxError: | |
raise NothingChanged from None | |
masked_dst = format_str(masked_src, mode=mode) | |
if not fast: | |
check_stability_and_equivalence(masked_src, masked_dst, mode=mode) | |
dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements) | |
dst = put_trailing_semicolon_back( | |
dst_without_trailing_semicolon, has_trailing_semicolon | |
) | |
dst = dst.rstrip("\n") | |
if dst == src: | |
raise NothingChanged from None | |
return dst | |
def validate_metadata(nb: MutableMapping[str, Any]) -> None: | |
"""If notebook is marked as non-Python, don't format it. | |
All notebook metadata fields are optional, see | |
https://nbformat.readthedocs.io/en/latest/format_description.html. So | |
if a notebook has empty metadata, we will try to parse it anyway. | |
""" | |
language = nb.get("metadata", {}).get("language_info", {}).get("name", None) | |
if language is not None and language != "python": | |
raise NothingChanged from None | |
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: | |
"""Format Jupyter notebook. | |
Operate cell-by-cell, only on code cells, only for Python notebooks. | |
If the ``.ipynb`` originally had a trailing newline, it'll be preserved. | |
""" | |
if not src_contents: | |
raise NothingChanged | |
trailing_newline = src_contents[-1] == "\n" | |
modified = False | |
nb = json.loads(src_contents) | |
validate_metadata(nb) | |
for cell in nb["cells"]: | |
if cell.get("cell_type", None) == "code": | |
try: | |
src = "".join(cell["source"]) | |
dst = format_cell(src, fast=fast, mode=mode) | |
except NothingChanged: | |
pass | |
else: | |
cell["source"] = dst.splitlines(keepends=True) | |
modified = True | |
if modified: | |
dst_contents = json.dumps(nb, indent=1, ensure_ascii=False) | |
if trailing_newline: | |
dst_contents = dst_contents + "\n" | |
return dst_contents | |
else: | |
raise NothingChanged | |
def format_str(src_contents: str, *, mode: Mode) -> str: | |
"""Reformat a string and return new contents. | |
`mode` determines formatting options, such as how many characters per line are | |
allowed. Example: | |
>>> import black | |
>>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode())) | |
def f(arg: str = "") -> None: | |
... | |
A more complex example: | |
>>> print( | |
... black.format_str( | |
... "def f(arg:str='')->None: hey", | |
... mode=black.Mode( | |
... target_versions={black.TargetVersion.PY36}, | |
... line_length=10, | |
... string_normalization=False, | |
... is_pyi=False, | |
... ), | |
... ), | |
... ) | |
def f( | |
arg: str = '', | |
) -> None: | |
hey | |
""" | |
dst_contents = _format_str_once(src_contents, mode=mode) | |
# Forced second pass to work around optional trailing commas (becoming | |
# forced trailing commas on pass 2) interacting differently with optional | |
# parentheses. Admittedly ugly. | |
if src_contents != dst_contents: | |
return _format_str_once(dst_contents, mode=mode) | |
return dst_contents | |
def _format_str_once(src_contents: str, *, mode: Mode) -> str: | |
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) | |
dst_blocks: List[LinesBlock] = [] | |
if mode.target_versions: | |
versions = mode.target_versions | |
else: | |
future_imports = get_future_imports(src_node) | |
versions = detect_target_versions(src_node, future_imports=future_imports) | |
context_manager_features = { | |
feature | |
for feature in {Feature.PARENTHESIZED_CONTEXT_MANAGERS} | |
if supports_feature(versions, feature) | |
} | |
normalize_fmt_off(src_node) | |
lines = LineGenerator(mode=mode, features=context_manager_features) | |
elt = EmptyLineTracker(mode=mode) | |
split_line_features = { | |
feature | |
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} | |
if supports_feature(versions, feature) | |
} | |
block: Optional[LinesBlock] = None | |
for current_line in lines.visit(src_node): | |
block = elt.maybe_empty_lines(current_line) | |
dst_blocks.append(block) | |
for line in transform_line( | |
current_line, mode=mode, features=split_line_features | |
): | |
block.content_lines.append(str(line)) | |
if dst_blocks: | |
dst_blocks[-1].after = 0 | |
dst_contents = [] | |
for block in dst_blocks: | |
dst_contents.extend(block.all_lines()) | |
if not dst_contents: | |
# Use decode_bytes to retrieve the correct source newline (CRLF or LF), | |
# and check if normalized_content has more than one line | |
normalized_content, _, newline = decode_bytes(src_contents.encode("utf-8")) | |
if "\n" in normalized_content: | |
return newline | |
return "" | |
return "".join(dst_contents) | |
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: | |
"""Return a tuple of (decoded_contents, encoding, newline). | |
`newline` is either CRLF or LF but `decoded_contents` is decoded with | |
universal newlines (i.e. only contains LF). | |
""" | |
srcbuf = io.BytesIO(src) | |
encoding, lines = tokenize.detect_encoding(srcbuf.readline) | |
if not lines: | |
return "", encoding, "\n" | |
newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n" | |
srcbuf.seek(0) | |
with io.TextIOWrapper(srcbuf, encoding) as tiow: | |
return tiow.read(), encoding, newline | |
def get_features_used( # noqa: C901 | |
node: Node, *, future_imports: Optional[Set[str]] = None | |
) -> Set[Feature]: | |
"""Return a set of (relatively) new Python features used in this file. | |
Currently looking for: | |
- f-strings; | |
- self-documenting expressions in f-strings (f"{x=}"); | |
- underscores in numeric literals; | |
- trailing commas after * or ** in function signatures and calls; | |
- positional only arguments in function signatures and lambdas; | |
- assignment expression; | |
- relaxed decorator syntax; | |
- usage of __future__ flags (annotations); | |
- print / exec statements; | |
- parenthesized context managers; | |
- match statements; | |
- except* clause; | |
- variadic generics; | |
""" | |
features: Set[Feature] = set() | |
if future_imports: | |
features |= { | |
FUTURE_FLAG_TO_FEATURE[future_import] | |
for future_import in future_imports | |
if future_import in FUTURE_FLAG_TO_FEATURE | |
} | |
for n in node.pre_order(): | |
if is_string_token(n): | |
value_head = n.value[:2] | |
if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}: | |
features.add(Feature.F_STRINGS) | |
if Feature.DEBUG_F_STRINGS not in features: | |
for span_beg, span_end in iter_fexpr_spans(n.value): | |
if n.value[span_beg : span_end - 1].rstrip().endswith("="): | |
features.add(Feature.DEBUG_F_STRINGS) | |
break | |
elif is_number_token(n): | |
if "_" in n.value: | |
features.add(Feature.NUMERIC_UNDERSCORES) | |
elif n.type == token.SLASH: | |
if n.parent and n.parent.type in { | |
syms.typedargslist, | |
syms.arglist, | |
syms.varargslist, | |
}: | |
features.add(Feature.POS_ONLY_ARGUMENTS) | |
elif n.type == token.COLONEQUAL: | |
features.add(Feature.ASSIGNMENT_EXPRESSIONS) | |
elif n.type == syms.decorator: | |
if len(n.children) > 1 and not is_simple_decorator_expression( | |
n.children[1] | |
): | |
features.add(Feature.RELAXED_DECORATORS) | |
elif ( | |
n.type in {syms.typedargslist, syms.arglist} | |
and n.children | |
and n.children[-1].type == token.COMMA | |
): | |
if n.type == syms.typedargslist: | |
feature = Feature.TRAILING_COMMA_IN_DEF | |
else: | |
feature = Feature.TRAILING_COMMA_IN_CALL | |
for ch in n.children: | |
if ch.type in STARS: | |
features.add(feature) | |
if ch.type == syms.argument: | |
for argch in ch.children: | |
if argch.type in STARS: | |
features.add(feature) | |
elif ( | |
n.type in {syms.return_stmt, syms.yield_expr} | |
and len(n.children) >= 2 | |
and n.children[1].type == syms.testlist_star_expr | |
and any(child.type == syms.star_expr for child in n.children[1].children) | |
): | |
features.add(Feature.UNPACKING_ON_FLOW) | |
elif ( | |
n.type == syms.annassign | |
and len(n.children) >= 4 | |
and n.children[3].type == syms.testlist_star_expr | |
): | |
features.add(Feature.ANN_ASSIGN_EXTENDED_RHS) | |
elif ( | |
n.type == syms.with_stmt | |
and len(n.children) > 2 | |
and n.children[1].type == syms.atom | |
): | |
atom_children = n.children[1].children | |
if ( | |
len(atom_children) == 3 | |
and atom_children[0].type == token.LPAR | |
and atom_children[1].type == syms.testlist_gexp | |
and atom_children[2].type == token.RPAR | |
): | |
features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS) | |
elif n.type == syms.match_stmt: | |
features.add(Feature.PATTERN_MATCHING) | |
elif ( | |
n.type == syms.except_clause | |
and len(n.children) >= 2 | |
and n.children[1].type == token.STAR | |
): | |
features.add(Feature.EXCEPT_STAR) | |
elif n.type in {syms.subscriptlist, syms.trailer} and any( | |
child.type == syms.star_expr for child in n.children | |
): | |
features.add(Feature.VARIADIC_GENERICS) | |
elif ( | |
n.type == syms.tname_star | |
and len(n.children) == 3 | |
and n.children[2].type == syms.star_expr | |
): | |
features.add(Feature.VARIADIC_GENERICS) | |
elif n.type in (syms.type_stmt, syms.typeparams): | |
features.add(Feature.TYPE_PARAMS) | |
return features | |
def detect_target_versions( | |
node: Node, *, future_imports: Optional[Set[str]] = None | |
) -> Set[TargetVersion]: | |
"""Detect the version to target based on the nodes used.""" | |
features = get_features_used(node, future_imports=future_imports) | |
return { | |
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version] | |
} | |
def get_future_imports(node: Node) -> Set[str]: | |
"""Return a set of __future__ imports in the file.""" | |
imports: Set[str] = set() | |
def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]: | |
for child in children: | |
if isinstance(child, Leaf): | |
if child.type == token.NAME: | |
yield child.value | |
elif child.type == syms.import_as_name: | |
orig_name = child.children[0] | |
assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports" | |
assert orig_name.type == token.NAME, "Invalid syntax parsing imports" | |
yield orig_name.value | |
elif child.type == syms.import_as_names: | |
yield from get_imports_from_children(child.children) | |
else: | |
raise AssertionError("Invalid syntax parsing imports") | |
for child in node.children: | |
if child.type != syms.simple_stmt: | |
break | |
first_child = child.children[0] | |
if isinstance(first_child, Leaf): | |
# Continue looking if we see a docstring; otherwise stop. | |
if ( | |
len(child.children) == 2 | |
and first_child.type == token.STRING | |
and child.children[1].type == token.NEWLINE | |
): | |
continue | |
break | |
elif first_child.type == syms.import_from: | |
module_name = first_child.children[1] | |
if not isinstance(module_name, Leaf) or module_name.value != "__future__": | |
break | |
imports |= set(get_imports_from_children(first_child.children[3:])) | |
else: | |
break | |
return imports | |
def assert_equivalent(src: str, dst: str) -> None: | |
"""Raise AssertionError if `src` and `dst` aren't equivalent.""" | |
try: | |
src_ast = parse_ast(src) | |
except Exception as exc: | |
raise AssertionError( | |
"cannot use --safe with this file; failed to parse source file AST: " | |
f"{exc}\n" | |
"This could be caused by running Black with an older Python version " | |
"that does not support new syntax used in your source file." | |
) from exc | |
try: | |
dst_ast = parse_ast(dst) | |
except Exception as exc: | |
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) | |
raise AssertionError( | |
f"INTERNAL ERROR: Black produced invalid code: {exc}. " | |
"Please report a bug on https://github.com/psf/black/issues. " | |
f"This invalid output might be helpful: {log}" | |
) from None | |
src_ast_str = "\n".join(stringify_ast(src_ast)) | |
dst_ast_str = "\n".join(stringify_ast(dst_ast)) | |
if src_ast_str != dst_ast_str: | |
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) | |
raise AssertionError( | |
"INTERNAL ERROR: Black produced code that is not equivalent to the" | |
" source. Please report a bug on " | |
f"https://github.com/psf/black/issues. This diff might be helpful: {log}" | |
) from None | |
def assert_stable(src: str, dst: str, mode: Mode) -> None: | |
"""Raise AssertionError if `dst` reformats differently the second time.""" | |
# We shouldn't call format_str() here, because that formats the string | |
# twice and may hide a bug where we bounce back and forth between two | |
# versions. | |
newdst = _format_str_once(dst, mode=mode) | |
if dst != newdst: | |
log = dump_to_file( | |
str(mode), | |
diff(src, dst, "source", "first pass"), | |
diff(dst, newdst, "first pass", "second pass"), | |
) | |
raise AssertionError( | |
"INTERNAL ERROR: Black produced different code on the second pass of the" | |
" formatter. Please report a bug on https://github.com/psf/black/issues." | |
f" This diff might be helpful: {log}" | |
) from None | |
def nullcontext() -> Iterator[None]: | |
"""Return an empty context manager. | |
To be used like `nullcontext` in Python 3.7. | |
""" | |
yield | |
def patched_main() -> None: | |
# PyInstaller patches multiprocessing to need freeze_support() even in non-Windows | |
# environments so just assume we always need to call it if frozen. | |
if getattr(sys, "frozen", False): | |
from multiprocessing import freeze_support | |
freeze_support() | |
main() | |
if __name__ == "__main__": | |
patched_main() | |