Spaces:
Paused
Paused
| import bisect | |
| import re | |
| import unicodedata | |
| from typing import Optional, Union | |
| from . import idnadata | |
| from .intranges import intranges_contain | |
# Combining class value that valid_contextj() compares against the result of
# _combining_class() to recognise a preceding virama.
_virama_combining_class = 9
# ASCII prefix that marks a Punycode-encoded label (an A-label).
_alabel_prefix = b"xn--"
# Label separators recognised by encode()/decode() in non-strict mode:
# U+002E, U+3002, U+FF0E and U+FF61.
_unicode_dots_re = re.compile("[\u002e\u3002\uff0e\uff61]")
class IDNAError(UnicodeError):
    """Root of the exception hierarchy for every IDNA-encoding related failure."""
class IDNABidiError(IDNAError):
    """Raised when a label does not satisfy the bidirectional requirements."""
class InvalidCodepoint(IDNAError):
    """Raised when a disallowed or unallocated codepoint is encountered."""
class InvalidCodepointContext(IDNAError):
    """Raised when a codepoint is not valid in the context where it appears."""
| def _combining_class(cp: int) -> int: | |
| v = unicodedata.combining(chr(cp)) | |
| if v == 0: | |
| if not unicodedata.name(chr(cp)): | |
| raise ValueError("Unknown character in unicodedata") | |
| return v | |
def _is_script(cp: str, script: str) -> bool:
    """Tell whether the character ``cp`` falls in the ranges listed for ``script``.

    The ranges are looked up in ``idnadata.scripts`` under the key ``script``.
    """
    code_point = ord(cp)
    script_ranges = idnadata.scripts[script]
    return intranges_contain(code_point, script_ranges)
| def _punycode(s: str) -> bytes: | |
| return s.encode("punycode") | |
| def _unot(s: int) -> str: | |
| return "U+{:04X}".format(s) | |
def valid_label_length(label: Union[bytes, str]) -> bool:
    """Return True when ``label`` is at most 63 units long, False otherwise."""
    return len(label) <= 63
def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
    """Return True when the whole name fits the length limit.

    The limit is 254 when ``trailing_dot`` is true and 253 otherwise.
    """
    limit = 254 if trailing_dot else 253
    return len(label) <= limit
def check_bidi(label: str, check_ltr: bool = False) -> bool:
    """Validate ``label`` against the Bidi rules.

    The rules are only enforced when the label contains a character of
    directionality R, AL or AN, or when ``check_ltr`` is true.

    Returns:
        True when the label passes (or when the rules do not apply).

    Raises:
        IDNABidiError: If a character has no known directionality or any of
            the rules is violated.
    """
    # First pass: decide whether the Bidi rules apply at all.
    has_rtl = False
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if not bidi_class:
            # String likely comes from a newer version of Unicode
            raise IDNABidiError("Unknown directionality in label {} at position {}".format(repr(label), position))
        if bidi_class in ("R", "AL", "AN"):
            has_rtl = True
    if not (has_rtl or check_ltr):
        return True

    # Bidi rule 1: the first character fixes the direction of the label.
    first_class = unicodedata.bidirectional(label[0])
    if first_class == "L":
        rtl = False
    elif first_class in ("R", "AL"):
        rtl = True
    else:
        raise IDNABidiError("First codepoint in label {} must be directionality L, R or AL".format(repr(label)))

    if rtl:
        # Bidi rules 2 and 3
        allowed = ("R", "AL", "AN", "EN", "ES", "CS", "ET", "ON", "BN", "NSM")
        enders = ("R", "AL", "EN", "AN")
        invalid_message = "Invalid direction for codepoint at position {} in a right-to-left label"
    else:
        # Bidi rules 5 and 6
        allowed = ("L", "EN", "ES", "CS", "ET", "ON", "BN", "NSM")
        enders = ("L", "EN")
        invalid_message = "Invalid direction for codepoint at position {} in a left-to-right label"

    ends_ok = False
    seen_number: Optional[str] = None
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if bidi_class not in allowed:
            raise IDNABidiError(invalid_message.format(position))

        # Trailing NSM characters do not change whether the ending is valid.
        if bidi_class in enders:
            ends_ok = True
        elif bidi_class != "NSM":
            ends_ok = False

        # Bidi rule 4: a right-to-left label may use only one numeral type.
        if rtl and bidi_class in ("AN", "EN"):
            if not seen_number:
                seen_number = bidi_class
            elif seen_number != bidi_class:
                raise IDNABidiError("Can not mix numeral types in a right-to-left label")

    if not ends_ok:
        raise IDNABidiError("Label ends with illegal codepoint directionality")
    return True
def check_initial_combiner(label: str) -> bool:
    """Reject labels whose first character has a general category starting with "M".

    Returns:
        True when the first character is not a combining mark.

    Raises:
        IDNAError: If the label begins with a combining character.
    """
    first_category = unicodedata.category(label[0])
    if first_category.startswith("M"):
        raise IDNAError("Label begins with an illegal combining character")
    return True
def check_hyphen_ok(label: str) -> bool:
    """Check the hyphen placement restrictions on ``label``.

    Returns:
        True when the label passes both checks.

    Raises:
        IDNAError: If the 3rd and 4th characters are both hyphens, or if the
            label starts or ends with a hyphen.
    """
    if label.startswith("--", 2):
        raise IDNAError("Label has disallowed hyphens in 3rd and 4th position")
    if "-" in (label[0], label[-1]):
        raise IDNAError("Label must not start or end with a hyphen")
    return True
def check_nfc(label: str) -> None:
    """Ensure ``label`` is unchanged by NFC normalization.

    Raises:
        IDNAError: If normalizing the label to NFC yields a different string.
    """
    normalized = unicodedata.normalize("NFC", label)
    if normalized == label:
        return
    raise IDNAError("Label must be in Normalization Form C")
def valid_contextj(label: str, pos: int) -> bool:
    """Decide whether the joiner at ``label[pos]`` is allowed in its context.

    Only U+200C and U+200D are handled; any other character yields False.
    Both are accepted directly after a character whose combining class
    equals ``_virama_combining_class``. U+200C is additionally accepted when
    the nearest non-transparent character before it has joining type L or D
    and the nearest non-transparent character after it has joining type R
    or D.

    Raises:
        ValueError: Propagated from ``_combining_class`` for the preceding
            character.
    """
    joiner = ord(label[pos])
    if joiner not in (0x200C, 0x200D):
        return False

    # Either joiner is fine immediately after a virama.
    if pos > 0 and _combining_class(ord(label[pos - 1])) == _virama_combining_class:
        return True
    if joiner == 0x200D:
        return False

    transparent = ord("T")

    def _first_visible_joins(indices, accepted) -> bool:
        # Skip transparent characters; the first other character decides.
        for index in indices:
            kind = idnadata.joining_types.get(ord(label[index]))
            if kind == transparent:
                continue
            return kind in accepted
        return False

    if not _first_visible_joins(range(pos - 1, -1, -1), (ord("L"), ord("D"))):
        return False
    return _first_visible_joins(range(pos + 1, len(label)), (ord("R"), ord("D")))
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
    """Decide whether the character at ``label[pos]`` is allowed in its context.

    Handles U+00B7, U+0375, U+05F3, U+05F4, U+30FB and the digit ranges
    U+0660..U+0669 and U+06F0..U+06F9; any other character yields False.
    The ``exception`` argument is accepted but not used.
    """
    cp_value = ord(label[pos])
    last_index = len(label) - 1

    if cp_value == 0x00B7:
        # Must sit between two "l" characters.
        return 0 < pos < last_index and label[pos - 1] == "l" and label[pos + 1] == "l"

    if cp_value == 0x0375:
        # The following character must be Greek.
        if pos < last_index and len(label) > 1:
            return _is_script(label[pos + 1], "Greek")
        return False

    if cp_value in (0x05F3, 0x05F4):
        # The preceding character must be Hebrew.
        if pos > 0:
            return _is_script(label[pos - 1], "Hebrew")
        return False

    if cp_value == 0x30FB:
        # Some other character of the label must be Hiragana, Katakana or Han.
        return any(
            _is_script(char, "Hiragana") or _is_script(char, "Katakana") or _is_script(char, "Han")
            for char in label
            if char != "\u30fb"
        )

    if 0x660 <= cp_value <= 0x669:
        # Not allowed together with any digit from U+06F0..U+06F9.
        return not any(0x6F0 <= ord(char) <= 0x06F9 for char in label)

    if 0x6F0 <= cp_value <= 0x6F9:
        # Not allowed together with any digit from U+0660..U+0669.
        return not any(0x660 <= ord(char) <= 0x0669 for char in label)

    return False
def check_label(label: Union[str, bytes, bytearray]) -> None:
    """Validate a single label, raising on the first problem found.

    Byte input is decoded as UTF-8 first. The label must be non-empty, pass
    ``check_nfc``, ``check_hyphen_ok`` and ``check_initial_combiner``, consist
    only of PVALID codepoints or CONTEXTJ/CONTEXTO codepoints that are valid
    in their context, and finally pass ``check_bidi``.

    Raises:
        InvalidCodepointContext: A CONTEXTJ or CONTEXTO codepoint appears in a
            context where it is not allowed.
        InvalidCodepoint: A codepoint is in none of the permitted classes.
        IDNAError: The label is empty, fails one of the other checks, or a
            joiner's context could not be evaluated because
            ``valid_contextj`` raised ``ValueError``.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode("utf-8")
    if len(label) == 0:
        raise IDNAError("Empty Label")

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for pos, cp in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes["PVALID"]):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTJ"]):
            # Only the context lookup is guarded. IDNAError derives from
            # UnicodeError, which is a ValueError, so raising
            # InvalidCodepointContext inside this try would be caught below
            # and replaced by the wrong exception type and message.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError as err:
                raise IDNAError(
                    "Unknown codepoint adjacent to joiner {} at position {} in {}".format(
                        _unot(cp_value), pos + 1, repr(label)
                    )
                ) from err
            if not joiner_ok:
                raise InvalidCodepointContext(
                    "Joiner {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
                )
        elif intranges_contain(cp_value, idnadata.codepoint_classes["CONTEXTO"]):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext(
                    "Codepoint {} not allowed at position {} in {}".format(_unot(cp_value), pos + 1, repr(label))
                )
        else:
            raise InvalidCodepoint(
                "Codepoint {} at position {} of {} not allowed".format(_unot(cp_value), pos + 1, repr(label))
            )

    check_bidi(label)
def alabel(label: str) -> bytes:
    """Convert ``label`` to its ASCII (A-label) byte form.

    A label that is already pure ASCII is validated through ``ulabel`` and
    returned as bytes unchanged. Any other label is validated with
    ``check_label`` and Punycode-encoded behind the ``xn--`` prefix.

    Raises:
        IDNAError: If validation fails or the resulting label is too long.
    """
    try:
        ascii_bytes = label.encode("ascii")
    except UnicodeEncodeError:
        # Not ASCII: fall through to the Punycode path below.
        pass
    else:
        ulabel(ascii_bytes)
        if valid_label_length(ascii_bytes):
            return ascii_bytes
        raise IDNAError("Label too long")

    check_label(label)
    encoded = _alabel_prefix + _punycode(label)
    if valid_label_length(encoded):
        return encoded
    raise IDNAError("Label too long")
def ulabel(label: Union[str, bytes, bytearray]) -> str:
    """Convert ``label`` to its Unicode (U-label) string form.

    A non-ASCII string is validated with ``check_label`` and returned as is.
    ASCII input is lower-cased; if it carries the ``xn--`` prefix the
    remainder is Punycode-decoded and validated, otherwise the label itself
    is validated and returned as a string.

    Raises:
        IDNAError: If the label is malformed or fails validation.
    """
    if isinstance(label, (bytes, bytearray)):
        raw = label
    else:
        try:
            raw = label.encode("ascii")
        except UnicodeEncodeError:
            # Already a Unicode label: it only needs validating.
            check_label(label)
            return label

    raw = raw.lower()
    if not raw.startswith(_alabel_prefix):
        check_label(raw)
        return raw.decode("ascii")

    raw = raw[len(_alabel_prefix) :]
    if not raw:
        raise IDNAError("Malformed A-label, no Punycode eligible content found")
    if raw.decode("ascii").endswith("-"):
        raise IDNAError("A-label must not end with a hyphen")

    try:
        decoded = raw.decode("punycode")
    except UnicodeError:
        raise IDNAError("Invalid A-label")
    check_label(decoded)
    return decoded
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
    """Re-map the characters in the string according to UTS46 processing.

    Each character is looked up in ``uts46data``. Depending on the row's
    status and the ``std3_rules`` / ``transitional`` flags it is kept,
    replaced by the row's replacement string, dropped (status "I"), or
    rejected. The result is returned NFC-normalized.

    Raises:
        InvalidCodepoint: If a character is not allowed.
    """
    from .uts46data import uts46data

    pieces = []
    for index, char in enumerate(domain):
        code_point = ord(char)
        try:
            # The first 256 rows are indexed directly; beyond that the row
            # covering the code point is found by bisection.
            if code_point < 256:
                row_index = code_point
            else:
                row_index = bisect.bisect_left(uts46data, (code_point, "Z")) - 1
            row = uts46data[row_index]
            status = row[1]
            replacement: Optional[str] = row[2] if len(row) == 3 else None

            keep_as_is = (
                status == "V"
                or (status == "D" and not transitional)
                or (status == "3" and not std3_rules and replacement is None)
            )
            use_replacement = replacement is not None and (
                status == "M" or (status == "3" and not std3_rules) or (status == "D" and transitional)
            )
            if keep_as_is:
                pieces.append(char)
            elif use_replacement:
                pieces.append(replacement)
            elif status != "I":
                # Routed through the handler below, like a failed lookup.
                raise IndexError()
        except IndexError:
            raise InvalidCodepoint(
                "Codepoint {} not allowed at position {} in {}".format(_unot(code_point), index + 1, repr(domain))
            )

    return unicodedata.normalize("NFC", "".join(pieces))
def encode(
    s: Union[str, bytes, bytearray],
    strict: bool = False,
    uts46: bool = False,
    std3_rules: bool = False,
    transitional: bool = False,
) -> bytes:
    """Encode a domain name to its ASCII form, label by label.

    Args:
        s: Domain name; byte input must be ASCII.
        strict: Split labels on "." only instead of every dot matched by
            ``_unicode_dots_re``.
        uts46: Run ``uts46_remap`` on the input first.
        std3_rules: Passed through to ``uts46_remap``.
        transitional: Passed through to ``uts46_remap``.

    Returns:
        The encoded labels joined with ``b"."``; a trailing dot is preserved.

    Raises:
        IDNAError: If the input is not ASCII bytes, the domain or a label is
            empty, a label is invalid, or the result is too long.
    """
    if not isinstance(s, str):
        try:
            s = str(s, "ascii")
        except UnicodeDecodeError:
            raise IDNAError("should pass a unicode string to the function rather than a byte string.")
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)

    labels = s.split(".") if strict else _unicode_dots_re.split(s)
    if not labels or labels == [""]:
        raise IDNAError("Empty domain")

    # An empty final label means the name ended with a dot.
    trailing_dot = labels[-1] == ""
    if trailing_dot:
        labels.pop()

    encoded_labels = []
    for label in labels:
        encoded = alabel(label)
        if not encoded:
            raise IDNAError("Empty label")
        encoded_labels.append(encoded)
    if trailing_dot:
        encoded_labels.append(b"")

    domain = b".".join(encoded_labels)
    if not valid_string_length(domain, trailing_dot):
        raise IDNAError("Domain too long")
    return domain
def decode(
    s: Union[str, bytes, bytearray],
    strict: bool = False,
    uts46: bool = False,
    std3_rules: bool = False,
) -> str:
    """Decode a domain name to its Unicode form, label by label.

    Args:
        s: Domain name; byte input must be ASCII.
        strict: Split labels on "." only instead of every dot matched by
            ``_unicode_dots_re``.
        uts46: Run ``uts46_remap`` (non-transitional) on the input first.
        std3_rules: Passed through to ``uts46_remap``.

    Returns:
        The decoded labels joined with "."; a trailing dot is preserved.

    Raises:
        IDNAError: If the input is not ASCII bytes, the domain or a label is
            empty, or a label is invalid.
    """
    if not isinstance(s, str):
        try:
            s = str(s, "ascii")
        except UnicodeDecodeError:
            raise IDNAError("Invalid ASCII in A-label")
    if uts46:
        s = uts46_remap(s, std3_rules, False)

    labels = s.split(".") if strict else _unicode_dots_re.split(s)
    if not labels or labels == [""]:
        raise IDNAError("Empty domain")

    # An empty final label means the name ended with a dot.
    trailing_dot = not labels[-1]
    if trailing_dot:
        labels.pop()

    decoded_labels = []
    for label in labels:
        decoded = ulabel(label)
        if not decoded:
            raise IDNAError("Empty label")
        decoded_labels.append(decoded)
    if trailing_dot:
        decoded_labels.append("")

    return ".".join(decoded_labels)