File size: 1,585 Bytes
6eefbd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar

from aiohttp.web_request import Request
from aiohttp.web_response import StreamResponse

if TYPE_CHECKING:
    # Type checkers see ``middleware`` as an identity decorator that preserves
    # the exact type of whatever callable it decorates.
    F = TypeVar("F", bound=Callable[..., Any])
    middleware: Callable[[F], F]
else:
    try:
        from aiohttp.web_middlewares import middleware
    except ImportError:
        # ``@middleware`` is deprecated, and its behaviour is the default from
        # aiohttp 4.0 on. If the import fails, substitute a pass-through
        # decorator so this module stays forward compatible.
        def middleware(func):
            """Return *func* unchanged (stand-in for the removed decorator)."""
            return func

# An aiohttp request handler: awaits to a response for the given request.
Handler = Callable[[Request], Awaitable[StreamResponse]]

# A middleware: receives the request plus the next handler to delegate to.
Middleware = Callable[
    [Request, Handler],
    Awaitable[StreamResponse],
]


def cors(allow_headers: Iterable[str]) -> Middleware:
    """Build a middleware that adds permissive CORS headers to responses.

    For any request carrying a non-empty ``Origin`` header, the response gets
    ``Access-Control-Allow-Origin: *`` and ``Access-Control-Expose-Headers: *``.
    ``OPTIONS`` responses also list the allowed request headers and the
    ``OPTIONS, POST`` methods. Preflight requests (``OPTIONS`` with an
    ``Access-Control-Request-Method`` header) are answered here with an empty
    ``StreamResponse`` and are never passed to the wrapped handler.

    Args:
        allow_headers: Header names to advertise in
            ``Access-Control-Allow-Headers``. The iterable is consumed exactly
            once, when the middleware is built, so generators and other
            one-shot iterables work on every request. A single ``str`` is
            treated as one header name, not as a sequence of characters.

    Returns:
        The ``impl`` coroutine function, wrapped with ``middleware``.
    """
    # ``str`` is itself an ``Iterable[str]``. Without this guard, "X-Token"
    # would be joined character by character into "X, -, T, o, ...".
    header_names: Iterable[str] = (
        (allow_headers,) if isinstance(allow_headers, str) else allow_headers
    )
    # Build the header value once. Joining the raw iterable inside ``impl``
    # would exhaust a generator on the first request and send an empty value
    # afterwards.
    allow_headers_value = ", ".join(header_names)
    allow_methods_value = "OPTIONS, POST"

    @middleware
    async def impl(request: Request, handler: Handler) -> StreamResponse:
        is_options = request.method == "OPTIONS"
        is_preflight = is_options and "Access-Control-Request-Method" in request.headers
        if is_preflight:
            # Answer the preflight directly; the application handler never sees it.
            resp = StreamResponse()
        else:
            resp = await handler(request)

        # Without an Origin header, return the response with no CORS headers added.
        if not request.headers.get("Origin"):
            return resp

        resp.headers["Access-Control-Allow-Origin"] = "*"
        resp.headers["Access-Control-Expose-Headers"] = "*"
        if is_options:
            resp.headers["Access-Control-Allow-Headers"] = allow_headers_value
            resp.headers["Access-Control-Allow-Methods"] = allow_methods_value

        return resp

    return impl