Spaces:
No application file
No application file
| import inspect | |
| from contextlib import AsyncExitStack, contextmanager | |
| from copy import deepcopy | |
| from typing import ( | |
| Any, | |
| Callable, | |
| Coroutine, | |
| Dict, | |
| ForwardRef, | |
| List, | |
| Mapping, | |
| Optional, | |
| Sequence, | |
| Tuple, | |
| Type, | |
| Union, | |
| cast, | |
| ) | |
| import anyio | |
| from fastapi import params | |
| from fastapi._compat import ( | |
| PYDANTIC_V2, | |
| ErrorWrapper, | |
| ModelField, | |
| Required, | |
| Undefined, | |
| _regenerate_error_with_loc, | |
| copy_field_info, | |
| create_body_model, | |
| evaluate_forwardref, | |
| field_annotation_is_scalar, | |
| get_annotation_from_field_info, | |
| get_missing_field_error, | |
| is_bytes_field, | |
| is_bytes_sequence_field, | |
| is_scalar_field, | |
| is_scalar_sequence_field, | |
| is_sequence_field, | |
| is_uploadfile_or_nonable_uploadfile_annotation, | |
| is_uploadfile_sequence_annotation, | |
| lenient_issubclass, | |
| sequence_types, | |
| serialize_sequence_value, | |
| value_is_sequence, | |
| ) | |
| from fastapi.background import BackgroundTasks | |
| from fastapi.concurrency import ( | |
| asynccontextmanager, | |
| contextmanager_in_threadpool, | |
| ) | |
| from fastapi.dependencies.models import Dependant, SecurityRequirement | |
| from fastapi.logger import logger | |
| from fastapi.security.base import SecurityBase | |
| from fastapi.security.oauth2 import OAuth2, SecurityScopes | |
| from fastapi.security.open_id_connect_url import OpenIdConnect | |
| from fastapi.utils import create_response_field, get_path_param_names | |
| from pydantic.fields import FieldInfo | |
| from starlette.background import BackgroundTasks as StarletteBackgroundTasks | |
| from starlette.concurrency import run_in_threadpool | |
| from starlette.datastructures import FormData, Headers, QueryParams, UploadFile | |
| from starlette.requests import HTTPConnection, Request | |
| from starlette.responses import Response | |
| from starlette.websockets import WebSocket | |
| from typing_extensions import Annotated, get_args, get_origin | |
# Message logged and raised by ``check_file_field`` when no ``multipart``
# module can be imported at all.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. \n'
    'You can install "python-multipart" with: \n'
    "\n"
    "pip install python-multipart\n"
)
# Message logged and raised by ``check_file_field`` when a ``multipart``
# module is importable but ``multipart.multipart.parse_options_header`` is not.
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. It seems you'
    ' installed "multipart" instead. \n'
    'You can remove "multipart" with: \n'
    "\n"
    "pip uninstall multipart\n"
    "\n"
    'And then install "python-multipart" with: \n'
    "\n"
    "pip install python-multipart\n"
)
def check_file_field(field: ModelField) -> None:
    """Verify that ``python-multipart`` is importable for a form field.

    Does nothing unless ``field.field_info`` is a ``params.Form`` instance.

    Raises:
        RuntimeError: if ``multipart`` cannot be imported, or if it can but
            ``multipart.multipart.parse_options_header`` cannot. The message
            is also logged at error level before raising.
    """
    if not isinstance(field.field_info, params.Form):
        return
    try:
        # __version__ is available in both multiparts, and can be mocked
        from multipart import __version__  # type: ignore
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error) from None
    assert __version__
    try:
        # parse_options_header is only available in the right multipart
        from multipart.multipart import parse_options_header  # type: ignore
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error) from None
    assert parse_options_header
def get_param_sub_dependant(
    *,
    param_name: str,
    depends: params.Depends,
    path: str,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the sub-``Dependant`` for a named parameter declared with ``Depends``.

    Asserts that ``depends.dependency`` is set, then delegates to
    ``get_sub_dependant`` using ``param_name`` as the sub-dependant's name.
    """
    dependency = depends.dependency
    assert dependency
    return get_sub_dependant(
        depends=depends,
        dependency=dependency,
        path=path,
        name=param_name,
        security_scopes=security_scopes,
    )
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build the sub-``Dependant`` for a ``Depends`` that is not bound to a parameter.

    Asserts that ``depends.dependency`` is callable, then delegates to
    ``get_sub_dependant`` without a name or inherited security scopes.
    """
    dependency = depends.dependency
    assert callable(
        dependency
    ), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=dependency, path=path)
def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the ``Dependant`` for one dependency, accumulating security scopes.

    Args:
        depends: The ``Depends``/``Security`` marker that declared the dependency.
        dependency: The callable to analyse.
        path: Route path, forwarded to ``get_dependant``.
        name: Parameter name the solved value is bound to, if any.
        security_scopes: Scopes inherited from the dependants above this one.
            The list is never modified.

    Returns:
        The ``Dependant`` produced by ``get_dependant``. When ``dependency`` is
        a ``SecurityBase`` instance, a ``SecurityRequirement`` is appended to
        its ``security_requirements``.
    """
    # Work on a private copy. Extending the caller's list in place would leak
    # this dependency's scopes into the parent Dependant (which stores the same
    # list object) and into every sibling dependency analysed afterwards.
    scopes: List[str] = list(security_scopes or [])
    if isinstance(depends, params.Security):
        scopes.extend(depends.scopes)
    security_requirement = None
    if isinstance(dependency, SecurityBase):
        # Only OAuth2 and OpenID Connect schemes receive the accumulated
        # scopes; every other scheme gets an empty scope list.
        use_scopes: List[str] = []
        if isinstance(dependency, (OAuth2, OpenIdConnect)):
            use_scopes = scopes
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )
    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        security_scopes=scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    return sub_dependant
# Type of the keys that ``get_flat_dependant`` records in its ``visited`` list
# (it appends ``Dependant.cache_key`` values): an optional callable paired with
# a tuple of strings.
# NOTE(review): the string tuple is presumably the security scopes - confirm
# against ``Dependant.cache_key`` in fastapi.dependencies.models.
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[List[CacheKey]] = None,
) -> Dependant:
    """Collapse a dependency tree into a single ``Dependant``.

    The result holds copies of ``dependant``'s own parameter lists and security
    requirements, extended with those of every sub-dependency (recursively).

    Args:
        dependant: Root of the tree to flatten.
        skip_repeats: When true, sub-dependencies whose ``cache_key`` has
            already been recorded in ``visited`` are not merged again.
        visited: Cache keys seen so far; the list is appended to in place.
    """
    seen = [] if visited is None else visited
    seen.append(dependant.cache_key)
    merged = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_schemes=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )
    merged_attributes = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
        "security_requirements",
    )
    for child in dependant.dependencies:
        if skip_repeats and child.cache_key in seen:
            continue
        flat_child = get_flat_dependant(child, skip_repeats=skip_repeats, visited=seen)
        for attribute in merged_attributes:
            getattr(merged, attribute).extend(getattr(flat_child, attribute))
    return merged
def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """Return every non-body parameter field of the flattened dependency tree.

    The fields are listed in the order path, query, header, cookie; repeated
    sub-dependencies are merged only once.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    return [
        *flat.path_params,
        *flat.query_params,
        *flat.header_params,
        *flat.cookie_params,
    ]
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``call``'s signature with parameter annotations resolved.

    Each parameter annotation is passed through ``get_typed_annotation`` using
    ``call.__globals__`` (or an empty dict when that attribute is missing) as
    the namespace. The returned signature carries parameters only; the return
    annotation is not copied over.
    """
    namespace = getattr(call, "__globals__", {})
    resolved_params = [
        parameter.replace(
            annotation=get_typed_annotation(parameter.annotation, namespace)
        )
        for parameter in inspect.signature(call).parameters.values()
    ]
    return inspect.Signature(resolved_params)
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Resolve a string annotation; return anything else unchanged.

    A string is wrapped in ``ForwardRef`` and evaluated with
    ``evaluate_forwardref``, using ``globalns`` as both namespaces.
    """
    if not isinstance(annotation, str):
        return annotation
    return evaluate_forwardref(ForwardRef(annotation), globalns, globalns)
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return ``call``'s resolved return annotation, or ``None`` if it has none.

    Resolution goes through ``get_typed_annotation`` with ``call.__globals__``
    (or an empty dict when that attribute is missing) as the namespace.
    """
    return_annotation = inspect.signature(call).return_annotation
    if return_annotation is inspect.Signature.empty:
        return None
    return get_typed_annotation(return_annotation, getattr(call, "__globals__", {}))
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Analyse ``call``'s signature and describe it as a ``Dependant``.

    Every parameter is classified with ``analyze_param`` and then recorded as
    exactly one of: a sub-dependency, a special injected object (request,
    response, ...), a body parameter, or a path/query/header/cookie parameter.

    Args:
        path: Route path; its placeholder names decide which parameters are
            path parameters.
        call: The endpoint or dependency callable to analyse.
        name: Name the solved value is bound to in the parent, if any.
        security_scopes: Scopes inherited from the dependants above.
        use_cache: Stored on the resulting ``Dependant``.
    """
    names_in_path = get_path_param_names(path)
    parameters = get_typed_signature(call).parameters
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        security_scopes=security_scopes,
        use_cache=use_cache,
    )
    for param_name, param in parameters.items():
        in_path = param_name in names_in_path
        type_annotation, depends, param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=in_path,
        )
        if depends is not None:
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param_name=param_name,
                    depends=depends,
                    path=path,
                    security_scopes=security_scopes,
                )
            )
            continue
        handled_as_special = add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=type_annotation,
            dependant=dependant,
        )
        if handled_as_special:
            assert (
                param_field is None
            ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
            continue
        assert param_field is not None
        if is_body_param(param_field=param_field, is_path_param=in_path):
            dependant.body_params.append(param_field)
        else:
            add_param_to_fields(field=param_field, dependant=dependant)
    return dependant
def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Record ``param_name`` on ``dependant`` if its type is a special injectable.

    Returns ``True`` when the annotation matched one of the special types and
    the matching ``*_param_name`` attribute was set, otherwise ``None``.
    """
    # Order matters: the first matching entry wins, exactly as in an
    # if/elif chain over these types.
    special_params = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (StarletteBackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for special_type, attribute in special_params:
        if lenient_issubclass(type_annotation, special_type):
            setattr(dependant, attribute, param_name)
            return True
    return None
def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
    """Classify one callable parameter from its annotation and default value.

    Args:
        param_name: Name of the parameter (used for the field name, the
            default alias and assertion messages).
        annotation: The parameter's annotation, or ``inspect.Signature.empty``.
        value: The parameter's default, or ``inspect.Signature.empty``.
        is_path_param: Whether the name appears in the route path.

    Returns:
        A ``(type_annotation, depends, field)`` tuple: the annotation with any
        ``Annotated`` wrapper removed, the ``Depends`` marker if one was
        declared, and the ``ModelField`` built for the parameter if it is a
        request parameter. ``depends`` and ``field`` are never both set.

    Raises:
        AssertionError: on conflicting declarations (see the individual
            assertion messages below).
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    # Step 1: look for FastAPI metadata inside ``Annotated[...]``.
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, (params.Param, params.Body, params.Depends))
        ]
        # When several FastAPI-specific markers are present, the last one wins.
        if fastapi_specific_annotations:
            fastapi_annotation: Union[
                FieldInfo, params.Depends, None
            ] = fastapi_specific_annotations[-1]
        else:
            fastapi_annotation = None
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert field_info.default is Undefined or field_info.default is Required, (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            # The default comes from the signature (`= value`), never from the
            # marker inside `Annotated`.
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = Required
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation
    # Step 2: look for FastAPI metadata passed as the default value. It must
    # not be combined with metadata already found in `Annotated`.
    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        # Note: the default-value object itself is used (not a copy), so the
        # assignments to it below modify the object from the signature.
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation
    # A bare `Depends()` with no callable uses the annotated type as dependency.
    if depends is not None and depends.dependency is None:
        depends.dependency = type_annotation
    # Step 3: special injectable types take no FastAPI metadata at all;
    # otherwise, with no explicit metadata, infer the parameter kind.
    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert (
            field_info is None
        ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else Required
        if is_path_param:
            # We might check here that `default_value is Required`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)
    # Step 4: turn the field info (explicit or inferred) into a ModelField.
    field = None
    if field_info is not None:
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            # A generic `Param` with no location is treated as a query parameter.
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = get_annotation_from_field_info(
            use_annotation,
            field_info,
            param_name,
        )
        # With no explicit alias, a truthy `convert_underscores` attribute
        # makes the alias the parameter name with "_" replaced by "-".
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_response_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (Required, Undefined),
            field_info=field_info,
        )
    return type_annotation, depends, field
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Tell whether ``param_field`` must be read from the request body.

    Path parameters, scalar fields, and scalar-sequence fields declared as
    ``Query`` or ``Header`` are not body parameters. Anything else must carry
    a ``params.Body`` field info and is a body parameter.

    Raises:
        AssertionError: if a path parameter is not scalar, or a non-scalar
            field is not declared with ``Body``.
    """
    if is_path_param:
        assert is_scalar_field(
            field=param_field
        ), "Path params must be of one of the supported types"
        return False
    if is_scalar_field(field=param_field):
        return False
    declared_as_query_or_header = isinstance(
        param_field.field_info, (params.Query, params.Header)
    )
    if declared_as_query_or_header and is_scalar_sequence_field(param_field):
        return False
    assert isinstance(
        param_field.field_info, params.Body
    ), f"Param: {param_field.name} can only be a request body, using Body()"
    return True
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append ``field`` to the ``dependant`` list matching its ``in_`` location.

    Raises:
        AssertionError: if the location is not path, query, header or cookie.
    """
    location = getattr(field.field_info, "in_", None)
    if location == params.ParamTypes.path:
        dependant.path_params.append(field)
        return
    if location == params.ParamTypes.query:
        dependant.query_params.append(field)
        return
    if location == params.ParamTypes.header:
        dependant.header_params.append(field)
        return
    assert (
        location == params.ParamTypes.cookie
    ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
    dependant.cookie_params.append(field)
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Tell whether calling ``call`` returns a coroutine that must be awaited.

    Returns ``True`` for ``async def`` functions and methods, for
    ``functools.partial`` objects wrapping them (on Python versions where
    ``inspect.iscoroutinefunction`` unwraps partials), and for instances whose
    ``__call__`` is a coroutine function. Classes always return ``False``.
    """
    # Test the callable directly first, as the generator checks in this module
    # do. Gating on ``inspect.isroutine`` skipped ``functools.partial``
    # wrappers around coroutine functions and reported them as synchronous.
    if inspect.iscoroutinefunction(call):
        return True
    if inspect.isclass(call):
        return False
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.iscoroutinefunction(dunder_call)
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Tell whether ``call`` or its ``__call__`` is an async generator function."""
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        getattr(call, "__call__", None)  # noqa: B004
    )
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Tell whether ``call`` or its ``__call__`` is a (sync) generator function."""
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        getattr(call, "__call__", None)  # noqa: B004
    )
async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
    """Enter a generator-based dependency on ``stack`` and return its yielded value.

    A sync generator is wrapped with ``contextmanager`` and
    ``contextmanager_in_threadpool``; an async generator is wrapped with
    ``asynccontextmanager``. ``call`` is invoked with ``sub_values`` as keyword
    arguments. Callers must pass a generator callable: for anything else no
    context manager is bound and the final line raises ``UnboundLocalError``.
    """
    if is_gen_callable(call):
        sync_manager = contextmanager(call)(**sub_values)
        manager = contextmanager_in_threadpool(sync_manager)
    elif is_async_gen_callable(call):
        manager = asynccontextmanager(call)(**sub_values)
    return await stack.enter_async_context(manager)
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
) -> Tuple[
    Dict[str, Any],
    List[Any],
    Optional[StarletteBackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve every input of ``dependant`` for the current request.

    Sub-dependencies are solved recursively first, then the dependant's own
    path/query/header/cookie/body parameters are extracted and validated, and
    finally the special injectables (request, response, background tasks,
    security scopes) are added.

    Args:
        request: The incoming HTTP request or WebSocket.
        dependant: Description of the callable whose inputs are being solved.
        body: Already-parsed request body, if any.
        background_tasks: Shared background tasks object; created on demand.
        response: Shared response object; created on the first call if missing.
        dependency_overrides_provider: Object exposing ``dependency_overrides``
            (a mapping from original callable to replacement), if any.
        dependency_cache: Values of already-solved dependencies by cache key.
        async_exit_stack: Stack on which generator dependencies are entered.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``
        where ``values`` maps parameter names to resolved values and ``errors``
        collects validation errors from this dependant and its sub-dependencies.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        # Fresh placeholder response for dependencies to set headers/status on;
        # its content-length header and status code are cleared.
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            # Swap in the override (if one is registered for this callable) and
            # re-analyse it, since its signature may differ from the original.
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )
        # Solve the sub-dependency's own inputs first (recursion).
        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            # Do not call a dependency whose inputs failed validation.
            errors.extend(sub_errors)
            continue
        # Reuse a cached value when allowed; otherwise invoke the callable in
        # the way that matches its kind (generator, coroutine, plain function).
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    # Extract and validate this dependant's own request parameters.
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    # Inject the special objects under whichever parameter names asked for them.
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
    """Extract and validate non-body parameters from a request mapping.

    Args:
        required_params: Fields to look up, by alias, in ``received_params``.
        received_params: Path params, query params, headers or cookies.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        errors for fields that were missing or failed validation. A missing
        optional field gets a deep copy of its default.
    """
    values = {}
    errors = []
    # Only these two container types are asked for every value of a key.
    supports_getlist = isinstance(received_params, (QueryParams, Headers))
    for field in required_params:
        if is_scalar_sequence_field(field) and supports_getlist:
            raw_value = received_params.getlist(field.alias) or field.default
        else:
            raw_value = received_params.get(field.alias)
        param_info = field.field_info
        assert isinstance(
            param_info, params.Param
        ), "Params must be subclasses of Param"
        loc = (param_info.in_.value, field.alias)
        if raw_value is None:
            if field.required:
                errors.append(get_missing_field_error(loc=loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue
        validated, field_errors = field.validate(raw_value, values, loc=loc)
        if isinstance(field_errors, ErrorWrapper):
            errors.append(field_errors)
        elif isinstance(field_errors, list):
            errors.extend(
                _regenerate_error_with_loc(errors=field_errors, loc_prefix=())
            )
        else:
            values[field.name] = validated
    return values, errors
async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract and validate body parameters from a parsed request body.

    Args:
        required_params: The body fields to fill.
        received_body: Parsed JSON body or form data, or ``None``.

    Returns:
        ``(values, errors)``: validated values keyed by field name, and the
        errors for fields that were missing or failed validation. A missing
        optional field gets a deep copy of its default.

    Notes:
        When there is exactly one body field and it is not embedded, the whole
        body is treated as that field's value and error locations are
        ``("body",)`` rather than ``("body", alias)``.
    """
    values: Dict[str, Any] = {}
    errors: List[Dict[str, Any]] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(first_field.field_info, "embed", None)
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Use this field's own metadata. Reusing the first field's info made
        # the Form/File checks below wrong for later fields, e.g. a ``bytes``
        # ``File()`` declared after a plain ``Form()`` was never read.
        field_info = field.field_info
        loc: Tuple[str, ...]
        if field_alias_omitted:
            loc = ("body",)
        else:
            loc = ("body", field.alias)

        value: Optional[Any] = None
        if received_body is not None:
            if is_sequence_field(field) and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    # The body is not a mapping (e.g. a JSON list or scalar).
                    errors.append(get_missing_field_error(loc))
                    continue

        is_form = isinstance(field_info, params.Form)
        # For form fields, an empty string or empty list counts as "not sent".
        if (
            value is None
            or (is_form and value == "")
            or (is_form and is_sequence_field(field) and len(value) == 0)
        ):
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        if (
            isinstance(field_info, params.File)
            and is_bytes_field(field)
            and isinstance(value, UploadFile)
        ):
            value = await value.read()
        elif (
            is_bytes_sequence_field(field)
            and isinstance(field_info, params.File)
            and value_is_sequence(value)
        ):
            # For types
            assert isinstance(value, sequence_types)  # type: ignore[arg-type]
            uploads = list(value)
            # Pre-sized so each read lands at its upload's index. Appending on
            # completion would order the contents by which read finished first.
            results: List[Any] = [None] * len(uploads)

            async def process_fn(
                index: int, fn: Callable[[], Coroutine[Any, Any, Any]]
            ) -> None:
                results[index] = await fn()  # noqa: B023

            async with anyio.create_task_group() as tg:
                for index, sub_value in enumerate(uploads):
                    tg.start_soon(process_fn, index, sub_value.read)
            value = serialize_sequence_value(field=field, value=results)

        v_, errors_ = field.validate(value, values, loc=loc)
        if isinstance(errors_, list):
            errors.extend(errors_)
        elif errors_:
            errors.append(errors_)
        else:
            values[field.name] = v_
    return values, errors
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Build the single field that describes the whole request body.

    Args:
        dependant: Root dependant whose flattened body parameters are used.
        name: Suffix for the generated body model name (``"Body_" + name``).

    Returns:
        ``None`` when there are no body parameters; the lone body parameter
        itself when there is only one distinct name and it is not embedded;
        otherwise a new field named ``"body"`` wrapping a generated model that
        contains every body parameter.

    Side effects:
        In the multi-parameter case every body parameter's field info gets
        ``embed = True``. ``check_file_field`` is run on the returned field.
    """
    body_params = get_flat_dependant(dependant).body_params
    if not body_params:
        return None
    first_param = body_params[0]
    embed = getattr(first_param.field_info, "embed", None)
    distinct_names = {param.name for param in body_params}
    if len(distinct_names) == 1 and not embed:
        check_file_field(first_param)
        return first_param
    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for param in body_params:
        setattr(param.field_info, "embed", True)  # noqa: B010
    BodyModel = create_body_model(fields=body_params, model_name="Body_" + name)
    required = any(param.required for param in body_params)
    field_info_kwargs: Dict[str, Any] = {
        "annotation": BodyModel,
        "alias": "body",
    }
    if not required:
        field_info_kwargs["default"] = None
    # File takes precedence over Form, which takes precedence over plain Body.
    field_info_class: Type[params.Body]
    if any(isinstance(param.field_info, params.File) for param in body_params):
        field_info_class = params.File
    elif any(isinstance(param.field_info, params.Form) for param in body_params):
        field_info_class = params.Form
    else:
        field_info_class = params.Body
        media_types = [
            param.field_info.media_type
            for param in body_params
            if isinstance(param.field_info, params.Body)
        ]
        # Only forward a media type when all body params agree on it.
        if len(set(media_types)) == 1:
            field_info_kwargs["media_type"] = media_types[0]
    final_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=required,
        alias="body",
        field_info=field_info_class(**field_info_kwargs),
    )
    check_file_field(final_field)
    return final_field