Spaces:
Running
Running
File size: 7,849 Bytes
47b2311 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
from __future__ import annotations
import re
import typing as t
from dataclasses import dataclass
from dataclasses import field
from .converters import ValidationError
from .exceptions import NoMatch
from .exceptions import RequestAliasRedirect
from .exceptions import RequestPath
from .rules import Rule
from .rules import RulePart
class SlashRequired(Exception):
    """Internal signal that a path only matches with a trailing slash.

    Raised while matching when every path part has been consumed and a
    strict-slashes rule is reachable through one more empty part. The
    matcher's ``match`` method catches it and raises ``RequestPath``
    with the slash appended.
    """
@dataclass
class State:
    """One node of the rule-matching state machine.

    A state holds the *rules* that terminate at it together with its
    outgoing transitions: *static* ones looked up by the literal part
    content, and *dynamic* ones stored alongside the rule part that
    guards them.
    """

    # Dynamic transitions as (rule part, next state) pairs.
    dynamic: list[tuple[RulePart, State]] = field(default_factory=list)
    # Rules that end at this state.
    rules: list[Rule] = field(default_factory=list)
    # Static transitions: part content -> next state.
    static: dict[str, State] = field(default_factory=dict)
class StateMachineMatcher:
    """Match a domain and path against rules stored in a tree of states.

    Rules are inserted with :meth:`add`, which walks or extends the tree
    one rule part at a time. :meth:`update` orders the dynamic
    transitions of every state by weight, and :meth:`match` walks the
    tree for a given request.
    """

    def __init__(self, merge_slashes: bool) -> None:
        """Create a matcher with an empty root state.

        :param merge_slashes: When true, a path that fails to match is
            retried with runs of consecutive slashes collapsed.
        """
        self._root = State()
        self.merge_slashes = merge_slashes

    def add(self, rule: Rule) -> None:
        """Register *rule*, creating any states its parts require."""
        state = self._root

        for part in rule._parts:
            if part.static:
                # Reuse the existing transition for this content or
                # create a new one.
                state = state.static.setdefault(part.content, State())
            else:
                for test_part, new_state in state.dynamic:
                    if test_part == part:
                        state = new_state
                        break
                else:
                    # No equal dynamic part yet, add a new transition.
                    new_state = State()
                    state.dynamic.append((part, new_state))
                    state = new_state

        state.rules.append(rule)

    def update(self) -> None:
        """Sort the dynamic transitions of every state by part weight."""

        def _update_state(state: State) -> None:
            state.dynamic.sort(key=lambda entry: entry[0].weight)

            for new_state in state.static.values():
                _update_state(new_state)

            for _, new_state in state.dynamic:
                _update_state(new_state)

        _update_state(self._root)

    def match(
        self, domain: str, path: str, method: str, websocket: bool
    ) -> tuple[Rule, t.MutableMapping[str, t.Any]]:
        """Find the rule matching the request and convert its values.

        :param domain: First part to match, ahead of the path parts.
        :param path: Request path; it is split on ``"/"`` into parts.
        :param method: Request method checked against ``rule.methods``.
        :param websocket: Must equal ``rule.websocket`` for a match.
        :return: The matched rule and a mapping of converted values,
            updated with the rule's defaults.
        :raises RequestPath: If the path matches once a trailing slash
            is appended, or once consecutive slashes are merged.
        :raises RequestAliasRedirect: If the matched rule is an alias
            and ``rule.map.redirect_defaults`` is set.
        :raises NoMatch: If nothing matches or a converter rejects a
            value.
        """
        # To match to a rule we need to start at the root state and
        # try to follow the transitions until we find a match, or find
        # there is no transition to follow.
        have_match_for: set[str] = set()
        websocket_mismatch = False

        def _match(
            state: State, parts: list[str], values: list[str]
        ) -> tuple[Rule, list[str]] | None:
            # This function is meant to be called recursively, and will
            # attempt to match the head part to the state's transitions.
            # ``have_match_for`` is only mutated, so it needs no
            # ``nonlocal`` declaration.
            nonlocal websocket_mismatch

            # The base case is when all parts have been matched via
            # transitions. Hence if there is a rule with methods &
            # websocket that work return it and the dynamic values
            # extracted.
            if not parts:
                for rule in state.rules:
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

                # Test if there is a match with this path with a
                # trailing slash, if so raise an exception to report
                # that matching is possible with an additional slash.
                if "" in state.static:
                    for rule in state.static[""].rules:
                        if websocket == rule.websocket and (
                            rule.methods is None or method in rule.methods
                        ):
                            if rule.strict_slashes:
                                raise SlashRequired()
                            else:
                                return rule, values

                return None

            part = parts[0]

            # To match this part try the static transitions first.
            if part in state.static:
                rv = _match(state.static[part], parts[1:], values)
                if rv is not None:
                    return rv

            # No match via the static transitions, so try the dynamic
            # ones.
            for test_part, new_state in state.dynamic:
                target = part
                remaining = parts[1:]

                # A final part indicates a transition that always
                # consumes the remaining parts i.e. transitions to a
                # final state.
                if test_part.final:
                    target = "/".join(parts)
                    remaining = []

                match = re.compile(test_part.content).match(target)

                if match is not None:
                    if test_part.suffixed:
                        # If a part_isolating=False part has a slash
                        # suffix, remove the suffix from the match and
                        # check for the slash redirect next.
                        suffix = match.groups()[-1]
                        if suffix == "/":
                            remaining = [""]

                    # Only named groups with the "__werkzeug_" prefix
                    # carry values; they are taken in sorted name order.
                    converter_groups = sorted(
                        match.groupdict().items(), key=lambda entry: entry[0]
                    )
                    groups = [
                        value
                        for key, value in converter_groups
                        if key.startswith("__werkzeug_")
                    ]
                    rv = _match(new_state, remaining, values + groups)
                    if rv is not None:
                        return rv

            # If there is no match and the only part left is a
            # trailing slash ("") consider rules that aren't
            # strict-slashes as these should match if there is a final
            # slash part.
            if parts == [""]:
                for rule in state.rules:
                    if rule.strict_slashes:
                        continue
                    if rule.methods is not None and method not in rule.methods:
                        have_match_for.update(rule.methods)
                    elif rule.websocket != websocket:
                        websocket_mismatch = True
                    else:
                        return rule, values

            return None

        try:
            rv = _match(self._root, [domain, *path.split("/")], [])
        except SlashRequired:
            raise RequestPath(f"{path}/") from None

        if self.merge_slashes and rv is None:
            # Try to match again, but with slashes merged. The
            # quantifier must be greedy: a lazy "{2,}?" consumes only
            # two slashes per match, so "///" would become "//" instead
            # of "/" and the retry would fail for runs of three or more.
            path = re.sub(r"/{2,}", "/", path)
            try:
                rv = _match(self._root, [domain, *path.split("/")], [])
            except SlashRequired:
                raise RequestPath(f"{path}/") from None

            if rv is None or rv[0].merge_slashes is False:
                raise NoMatch(have_match_for, websocket_mismatch)
            else:
                # The merged path matches, report it as the path to use.
                raise RequestPath(f"{path}")
        elif rv is not None:
            rule, values = rv

            result: dict[str, t.Any] = {}
            for name, value in zip(rule._converters.keys(), values):
                try:
                    value = rule._converters[name].to_python(value)
                except ValidationError:
                    raise NoMatch(have_match_for, websocket_mismatch) from None
                result[str(name)] = value

            if rule.defaults:
                result.update(rule.defaults)

            if rule.alias and rule.map.redirect_defaults:
                raise RequestAliasRedirect(result, rule.endpoint)

            return rule, result

        raise NoMatch(have_match_for, websocket_mismatch)
|