Spaces:
Runtime error
Runtime error
import asyncio | |
import logging | |
from concurrent.futures import Executor, ProcessPoolExecutor | |
from datetime import datetime, timezone | |
from functools import partial | |
from multiprocessing import freeze_support | |
from typing import Set, Tuple | |
try: | |
from aiohttp import web | |
from .middlewares import cors | |
except ImportError as ie: | |
raise ImportError( | |
f"aiohttp dependency is not installed: {ie}. " | |
+ "Please re-install black with the '[d]' extra install " | |
+ "to obtain aiohttp_cors: `pip install black[d]`" | |
) from None | |
import click | |
import black | |
from _black_version import version as __version__ | |
from black.concurrency import maybe_install_uvloop | |
# This is used internally by tests to shut down the server prematurely | |
_stop_signal = asyncio.Event() | |
# Request headers | |
PROTOCOL_VERSION_HEADER = "X-Protocol-Version" | |
LINE_LENGTH_HEADER = "X-Line-Length" | |
PYTHON_VARIANT_HEADER = "X-Python-Variant" | |
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line" | |
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization" | |
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma" | |
PREVIEW = "X-Preview" | |
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe" | |
DIFF_HEADER = "X-Diff" | |
BLACK_HEADERS = [ | |
PROTOCOL_VERSION_HEADER, | |
LINE_LENGTH_HEADER, | |
PYTHON_VARIANT_HEADER, | |
SKIP_SOURCE_FIRST_LINE, | |
SKIP_STRING_NORMALIZATION_HEADER, | |
SKIP_MAGIC_TRAILING_COMMA, | |
PREVIEW, | |
FAST_OR_SAFE_HEADER, | |
DIFF_HEADER, | |
] | |
# Response headers | |
BLACK_VERSION_HEADER = "X-Black-Version" | |
class InvalidVariantHeader(Exception): | |
pass | |
def main(bind_host: str, bind_port: int) -> None: | |
logging.basicConfig(level=logging.INFO) | |
app = make_app() | |
ver = black.__version__ | |
black.out(f"blackd version {ver} listening on {bind_host} port {bind_port}") | |
web.run_app(app, host=bind_host, port=bind_port, handle_signals=True, print=None) | |
def make_app() -> web.Application: | |
app = web.Application( | |
middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))] | |
) | |
executor = ProcessPoolExecutor() | |
app.add_routes([web.post("/", partial(handle, executor=executor))]) | |
return app | |
async def handle(request: web.Request, executor: Executor) -> web.Response: | |
headers = {BLACK_VERSION_HEADER: __version__} | |
try: | |
if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1": | |
return web.Response( | |
status=501, text="This server only supports protocol version 1" | |
) | |
try: | |
line_length = int( | |
request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH) | |
) | |
except ValueError: | |
return web.Response(status=400, text="Invalid line length header value") | |
if PYTHON_VARIANT_HEADER in request.headers: | |
value = request.headers[PYTHON_VARIANT_HEADER] | |
try: | |
pyi, versions = parse_python_variant_header(value) | |
except InvalidVariantHeader as e: | |
return web.Response( | |
status=400, | |
text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}", | |
) | |
else: | |
pyi = False | |
versions = set() | |
skip_string_normalization = bool( | |
request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False) | |
) | |
skip_magic_trailing_comma = bool( | |
request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False) | |
) | |
skip_source_first_line = bool( | |
request.headers.get(SKIP_SOURCE_FIRST_LINE, False) | |
) | |
preview = bool(request.headers.get(PREVIEW, False)) | |
fast = False | |
if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast": | |
fast = True | |
mode = black.FileMode( | |
target_versions=versions, | |
is_pyi=pyi, | |
line_length=line_length, | |
skip_source_first_line=skip_source_first_line, | |
string_normalization=not skip_string_normalization, | |
magic_trailing_comma=not skip_magic_trailing_comma, | |
preview=preview, | |
) | |
req_bytes = await request.content.read() | |
charset = request.charset if request.charset is not None else "utf8" | |
req_str = req_bytes.decode(charset) | |
then = datetime.now(timezone.utc) | |
header = "" | |
if skip_source_first_line: | |
first_newline_position: int = req_str.find("\n") + 1 | |
header = req_str[:first_newline_position] | |
req_str = req_str[first_newline_position:] | |
loop = asyncio.get_event_loop() | |
formatted_str = await loop.run_in_executor( | |
executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode) | |
) | |
# Preserve CRLF line endings | |
if req_str[req_str.find("\n") - 1] == "\r": | |
formatted_str = formatted_str.replace("\n", "\r\n") | |
# If, after swapping line endings, nothing changed, then say so | |
if formatted_str == req_str: | |
raise black.NothingChanged | |
# Put the source first line back | |
req_str = header + req_str | |
formatted_str = header + formatted_str | |
# Only output the diff in the HTTP response | |
only_diff = bool(request.headers.get(DIFF_HEADER, False)) | |
if only_diff: | |
now = datetime.now(timezone.utc) | |
src_name = f"In\t{then}" | |
dst_name = f"Out\t{now}" | |
loop = asyncio.get_event_loop() | |
formatted_str = await loop.run_in_executor( | |
executor, | |
partial(black.diff, req_str, formatted_str, src_name, dst_name), | |
) | |
return web.Response( | |
content_type=request.content_type, | |
charset=charset, | |
headers=headers, | |
text=formatted_str, | |
) | |
except black.NothingChanged: | |
return web.Response(status=204, headers=headers) | |
except black.InvalidInput as e: | |
return web.Response(status=400, headers=headers, text=str(e)) | |
except Exception as e: | |
logging.exception("Exception during handling a request") | |
return web.Response(status=500, headers=headers, text=str(e)) | |
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]: | |
if value == "pyi": | |
return True, set() | |
else: | |
versions = set() | |
for version in value.split(","): | |
if version.startswith("py"): | |
version = version[len("py") :] | |
if "." in version: | |
major_str, *rest = version.split(".") | |
else: | |
major_str = version[0] | |
rest = [version[1:]] if len(version) > 1 else [] | |
try: | |
major = int(major_str) | |
if major not in (2, 3): | |
raise InvalidVariantHeader("major version must be 2 or 3") | |
if len(rest) > 0: | |
minor = int(rest[0]) | |
if major == 2: | |
raise InvalidVariantHeader("Python 2 is not supported") | |
else: | |
# Default to lowest supported minor version. | |
minor = 7 if major == 2 else 3 | |
version_str = f"PY{major}{minor}" | |
if major == 3 and not hasattr(black.TargetVersion, version_str): | |
raise InvalidVariantHeader(f"3.{minor} is not supported") | |
versions.add(black.TargetVersion[version_str]) | |
except (KeyError, ValueError): | |
raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None | |
return False, versions | |
def patched_main() -> None: | |
maybe_install_uvloop() | |
freeze_support() | |
main() | |
if __name__ == "__main__": | |
patched_main() | |