Spaces:
Runtime error
Runtime error
from . import idnadata | |
import bisect | |
import unicodedata | |
import re | |
from typing import Union, Optional | |
from .intranges import intranges_contain | |
_virama_combining_class = 9 | |
_alabel_prefix = b'xn--' | |
_unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]') | |
class IDNAError(UnicodeError):
    """Root of the exception hierarchy for every IDNA encoding/decoding failure.

    Being a UnicodeError it is also a ValueError, so generic handlers for
    either of those catch it as well.
    """
class IDNABidiError(IDNAError):
    """Raised when a label violates the bidirectional (Bidi Rule) requirements."""
class InvalidCodepoint(IDNAError):
    """Raised for a code point that is disallowed or unallocated in the input."""
class InvalidCodepointContext(IDNAError):
    """Raised when a context-dependent code point appears where its context rule forbids it."""
def _combining_class(cp: int) -> int: | |
v = unicodedata.combining(chr(cp)) | |
if v == 0: | |
if not unicodedata.name(chr(cp)): | |
raise ValueError('Unknown character in unicodedata') | |
return v | |
def _is_script(cp: str, script: str) -> bool:
    """Tell whether the single character *cp* lies in the ranges of *script*.

    *script* must be a key of ``idnadata.scripts`` (e.g. ``'Greek'``).
    """
    code_point = ord(cp)
    return intranges_contain(code_point, idnadata.scripts[script])
def _punycode(s: str) -> bytes: | |
return s.encode('punycode') | |
def _unot(s: int) -> str: | |
return 'U+{:04X}'.format(s) | |
def valid_label_length(label: Union[bytes, str]) -> bool:
    """Return True when *label* is no longer than 63 units (the per-label limit)."""
    return len(label) <= 63
def valid_string_length(label: Union[bytes, str], trailing_dot: bool) -> bool:
    """Return True when the whole domain *label* fits the length limit.

    The limit is 253 units, or 254 when the name carries a trailing dot.
    """
    limit = 254 if trailing_dot else 253
    return len(label) <= limit
def check_bidi(label: str, check_ltr: bool = False) -> bool:
    """Validate *label* against the numbered Bidi rules.

    The rules are only enforced when the label contains at least one
    character of bidi class R, AL or AN, unless *check_ltr* is true, in which
    case purely left-to-right labels are checked as well.

    Returns:
        True when the label passes or the rules do not apply.

    Raises:
        IDNABidiError: a character has no known bidi class, or one of the
            rules below is violated.
    """
    # Pass 1: every character needs a known bidi class; remember RTL content.
    has_rtl = False
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if bidi_class == '':
            # String likely comes from a newer version of Unicode
            raise IDNABidiError('Unknown directionality in label {} at position {}'.format(repr(label), position))
        if bidi_class in ('R', 'AL', 'AN'):
            has_rtl = True
    if not has_rtl and not check_ltr:
        return True

    # Rule 1: the first character fixes the label's direction.
    leading_class = unicodedata.bidirectional(label[0])
    if leading_class in ('R', 'AL'):
        is_rtl = True
    elif leading_class == 'L':
        is_rtl = False
    else:
        raise IDNABidiError('First codepoint in label {} must be directionality L, R or AL'.format(repr(label)))

    ends_ok = False
    if is_rtl:
        seen_number_class = None  # type: Optional[str]
        for position, char in enumerate(label, 1):
            bidi_class = unicodedata.bidirectional(char)
            # Rule 2: only these classes may appear in a right-to-left label.
            if bidi_class not in ('R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM'):
                raise IDNABidiError('Invalid direction for codepoint at position {} in a right-to-left label'.format(position))
            # Rule 3: the label must end in R, AL, EN or AN, followed by
            # zero or more NSM characters.
            if bidi_class in ('R', 'AL', 'EN', 'AN'):
                ends_ok = True
            elif bidi_class != 'NSM':
                ends_ok = False
            # Rule 4: EN and AN number classes must not both occur.
            if bidi_class in ('AN', 'EN'):
                if seen_number_class is None:
                    seen_number_class = bidi_class
                elif seen_number_class != bidi_class:
                    raise IDNABidiError('Can not mix numeral types in a right-to-left label')
    else:
        for position, char in enumerate(label, 1):
            bidi_class = unicodedata.bidirectional(char)
            # Rule 5: only these classes may appear in a left-to-right label.
            if bidi_class not in ('L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM'):
                raise IDNABidiError('Invalid direction for codepoint at position {} in a left-to-right label'.format(position))
            # Rule 6: the label must end in L or EN, followed by zero or
            # more NSM characters.
            if bidi_class in ('L', 'EN'):
                ends_ok = True
            elif bidi_class != 'NSM':
                ends_ok = False

    if not ends_ok:
        raise IDNABidiError('Label ends with illegal codepoint directionality')
    return True
def check_initial_combiner(label: str) -> bool:
    """Reject a label whose first character is a mark (general category M*).

    Returns True when the label is acceptable; raises IDNAError otherwise.
    """
    if unicodedata.category(label[0]).startswith('M'):
        raise IDNAError('Label begins with an illegal combining character')
    return True
def check_hyphen_ok(label: str) -> bool:
    """Enforce hyphen placement rules on *label*.

    A label may not have hyphens in both the 3rd and 4th positions, and may
    not begin or end with a hyphen.  Returns True when acceptable; raises
    IDNAError otherwise.
    """
    if label[2:4] == '--':
        raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
    if '-' in (label[0], label[-1]):
        raise IDNAError('Label must not start or end with a hyphen')
    return True
def check_nfc(label: str) -> None:
    """Raise IDNAError unless *label* is already in Unicode Normalization Form C."""
    normalized = unicodedata.normalize('NFC', label)
    if normalized == label:
        return
    raise IDNAError('Label must be in Normalization Form C')
def _effective_joining_type(char: str) -> Optional[int]:
    """Return *char*'s joining type as ``ord`` of its letter (e.g. ``ord('D')``), or None.

    The value comes from ``idnadata.joining_types``.  A character missing from
    that table whose General Category is Mn, Me or Cf is reported as
    transparent (T), following the derivation rule stated in Unicode's
    ArabicShaping.txt, so combining marks are skipped even if the table only
    carries explicitly listed entries.
    """
    joining_type = idnadata.joining_types.get(ord(char))
    if joining_type is None and unicodedata.category(char) in ('Mn', 'Me', 'Cf'):
        return ord('T')
    return joining_type


def _nearest_joining_type_in(label: str, indices: range, accepted: tuple) -> bool:
    """Return True if the first non-transparent character of *label* visited
    in *indices* order has a joining type in *accepted*.

    Returns False when that character has any other joining type (including
    none), or when only transparent characters are found.
    """
    for i in indices:
        joining_type = _effective_joining_type(label[i])
        if joining_type == ord('T'):
            continue
        # The first non-transparent neighbour decides; scanning further
        # would both violate the rule and make each joiner cost O(n).
        return joining_type in accepted
    return False


def valid_contextj(label: str, pos: int) -> bool:
    """Apply the CONTEXTJ rule (RFC 5892 Appendix A.1/A.2) to ``label[pos]``.

    * U+200C ZERO WIDTH NON-JOINER is allowed directly after a virama, or when
      it matches
      ``(Joining_Type:{L,D})(Joining_Type:T)*\\u200C(Joining_Type:T)*(Joining_Type:{R,D})``,
      i.e. the nearest non-transparent character before it is L/D-joining
      and the nearest one after it is R/D-joining.
    * U+200D ZERO WIDTH JOINER is allowed only directly after a virama.

    Any other code point yields False.

    Raises:
        ValueError: propagated from ``_combining_class`` when the character
            before a joiner is unknown to ``unicodedata``.
    """
    cp_value = ord(label[pos])
    if cp_value not in (0x200c, 0x200d):
        return False

    # Both joiners are permitted straight after a virama.
    if pos > 0 and _combining_class(ord(label[pos - 1])) == _virama_combining_class:
        return True
    if cp_value == 0x200d:
        return False

    # ZERO WIDTH NON-JOINER: check the left context, then the right context.
    if not _nearest_joining_type_in(label, range(pos - 1, -1, -1), (ord('L'), ord('D'))):
        return False
    return _nearest_joining_type_in(label, range(pos + 1, len(label)), (ord('R'), ord('D')))
def valid_contexto(label: str, pos: int, exception: bool = False) -> bool:
    """Apply the CONTEXTO rule for the code point at ``label[pos]``.

    Handled code points:

    * U+00B7 MIDDLE DOT: must sit between two ``l`` characters.
    * U+0375 GREEK LOWER NUMERAL SIGN: the next character must be Greek.
    * U+05F3/U+05F4 HEBREW GERESH/GERSHAYIM: the previous character must be
      Hebrew.
    * U+30FB KATAKANA MIDDLE DOT: some other character in the label must be
      Hiragana, Katakana or Han.
    * U+0660..U+0669 and U+06F0..U+06F9 digits: the two ranges must not be
      mixed in one label.

    Every other code point yields False.  *exception* is accepted but unused.
    """
    cp_value = ord(label[pos])
    last = len(label) - 1

    if cp_value == 0x00b7:
        return 0 < pos < last and label[pos - 1] == 'l' and label[pos + 1] == 'l'

    if cp_value == 0x0375:
        if pos < last and len(label) > 1:
            return _is_script(label[pos + 1], 'Greek')
        return False

    if cp_value in (0x05f3, 0x05f4):
        if pos > 0:
            return _is_script(label[pos - 1], 'Hebrew')
        return False

    if cp_value == 0x30fb:
        return any(
            _is_script(other, 'Hiragana') or _is_script(other, 'Katakana') or _is_script(other, 'Han')
            for other in label
            if other != '\u30fb'
        )

    if 0x660 <= cp_value <= 0x669:
        return not any(0x6f0 <= ord(other) <= 0x06f9 for other in label)

    if 0x6f0 <= cp_value <= 0x6f9:
        return not any(0x660 <= ord(other) <= 0x0669 for other in label)

    return False
def check_label(label: Union[str, bytes, bytearray]) -> None:
    """Validate a single label (``str``, or its UTF-8 encoded bytes).

    Checks, in order: non-empty, NFC, hyphen placement, no leading combining
    mark, the class of every code point (PVALID, CONTEXTJ, CONTEXTO, anything
    else rejected), and finally the Bidi rule.

    Raises:
        IDNAError: empty label, non-NFC text, bad hyphens, leading combining
            mark, or a joiner adjacent to a code point unknown to
            ``unicodedata``.
        InvalidCodepointContext: a CONTEXTJ or CONTEXTO code point whose
            context rule fails.
        InvalidCodepoint: a code point in none of the allowed classes.
        IDNABidiError: the Bidi rule fails.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode('utf-8')
    if len(label) == 0:
        raise IDNAError('Empty Label')

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for (pos, cp) in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
            # Only the context evaluation may legitimately raise ValueError
            # (from _combining_class).  InvalidCodepointContext must be raised
            # outside the try: it is an IDNAError -> UnicodeError -> ValueError
            # and would otherwise be caught here and replaced by the generic
            # "Unknown codepoint" error.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError as err:
                raise IDNAError('Unknown codepoint adjacent to joiner {} at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label))) from err
            if not joiner_ok:
                raise InvalidCodepointContext('Joiner {} not allowed at position {} in {}'.format(
                    _unot(cp_value), pos+1, repr(label)))
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext('Codepoint {} not allowed at position {} in {}'.format(_unot(cp_value), pos+1, repr(label)))
        else:
            raise InvalidCodepoint('Codepoint {} at position {} of {} not allowed'.format(_unot(cp_value), pos+1, repr(label)))

    check_bidi(label)
def alabel(label: str) -> bytes:
    """Convert one label to its A-label (ASCII bytes) form.

    A pure-ASCII label is validated by running it through ``ulabel``, so an
    existing ``xn--`` label must decode cleanly.  The ASCII bytes are then
    returned exactly as encoded.  Any other label is checked with
    ``check_label`` and Punycode-encoded behind the ``xn--`` prefix.

    Raises:
        IDNAError: the result exceeds 63 bytes, or as raised by ``ulabel`` /
            ``check_label``.
    """
    try:
        ascii_bytes = label.encode('ascii')
    except UnicodeEncodeError:
        pass
    else:
        ulabel(ascii_bytes)
        if not valid_label_length(ascii_bytes):
            raise IDNAError('Label too long')
        return ascii_bytes

    # Defensive guard; an empty str already took the ASCII path above.
    if not label:
        raise IDNAError('No Input')
    label = str(label)
    check_label(label)
    encoded = _alabel_prefix + _punycode(label)
    if not valid_label_length(encoded):
        raise IDNAError('Label too long')
    return encoded
def ulabel(label: Union[str, bytes, bytearray]) -> str:
    """Convert one label to its U-label (Unicode ``str``) form, validating it.

    * A ``str`` that is not pure ASCII is checked with ``check_label`` and
      returned unchanged.
    * Otherwise the ASCII bytes are lower-cased.  Without the ``xn--`` prefix
      they are checked and returned as ``str``.  With the prefix, the rest is
      Punycode-decoded and the decoded text is checked and returned.

    Raises:
        IDNAError: empty or hyphen-terminated payload after ``xn--``,
            undecodable Punycode, or any ``check_label`` failure.
    """
    if isinstance(label, (bytes, bytearray)):
        label_bytes = label
    else:
        try:
            label_bytes = label.encode('ascii')
        except UnicodeEncodeError:
            check_label(label)
            return label

    label_bytes = label_bytes.lower()
    if not label_bytes.startswith(_alabel_prefix):
        check_label(label_bytes)
        return label_bytes.decode('ascii')

    payload = label_bytes[len(_alabel_prefix):]
    if not payload:
        raise IDNAError('Malformed A-label, no Punycode eligible content found')
    if payload.decode('ascii')[-1] == '-':
        raise IDNAError('A-label must not end with a hyphen')
    try:
        decoded = payload.decode('punycode')
    except UnicodeError:
        raise IDNAError('Invalid A-label')
    check_label(decoded)
    return decoded
def uts46_remap(domain: str, std3_rules: bool = True, transitional: bool = False) -> str:
    """Re-map the characters in the string according to UTS46 processing.

    Each character's row is looked up in ``uts46data``.  Row index 1 holds a
    status letter and, when present, index 2 holds a replacement string.
    Depending on the status and the *std3_rules* / *transitional* flags, the
    character is kept, replaced, or dropped (status ``'I'``); anything else
    is rejected.  The result is NFC-normalized.

    Raises:
        InvalidCodepoint: a character whose status is not accepted under
            the given flags.
    """
    from .uts46data import uts46data
    pieces = []
    for position, char in enumerate(domain, 1):
        code_point = ord(char)
        try:
            if code_point < 256:
                row_index = code_point
            else:
                # Last row whose key sorts at or before (code_point, 'Z'),
                # i.e. the range containing code_point (assumes uts46data
                # is sorted by starting code point, as bisect requires).
                row_index = bisect.bisect_left(uts46data, (code_point, 'Z')) - 1
            row = uts46data[row_index]
            status = row[1]
            replacement = row[2] if len(row) == 3 else None  # type: Optional[str]
            keep_as_is = (status == 'V'
                          or (status == 'D' and not transitional)
                          or (status == '3' and not std3_rules and replacement is None))
            if keep_as_is:
                pieces.append(char)
            elif replacement is not None and (status == 'M'
                                              or (status == '3' and not std3_rules)
                                              or (status == 'D' and transitional)):
                pieces.append(replacement)
            elif status != 'I':
                # Funnel every rejection into the single handler below.
                raise IndexError()
        except IndexError:
            raise InvalidCodepoint(
                'Codepoint {} not allowed at position {} in {}'.format(
                    _unot(code_point), position, repr(domain)))
    return unicodedata.normalize('NFC', ''.join(pieces))
def encode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False, transitional: bool = False) -> bytes:
    """Encode a whole domain name to A-label bytes.

    Args:
        s: the domain; ``bytes``/``bytearray`` input must be pure ASCII.
        strict: split only on U+002E FULL STOP instead of every dot accepted
            by ``_unicode_dots_re``.
        uts46: run ``uts46_remap`` before splitting.
        std3_rules: forwarded to ``uts46_remap``.
        transitional: forwarded to ``uts46_remap``.

    Returns:
        The labels converted by ``alabel`` joined with ``b'.'``; a trailing
        dot in the input is preserved.

    Raises:
        IDNAError: non-ASCII bytes input, an empty domain or label, or a
            result longer than 253 bytes (254 with a trailing dot), plus any
            error raised while converting a label.
    """
    if isinstance(s, (bytes, bytearray)):
        try:
            s = s.decode('ascii')
        except UnicodeDecodeError:
            raise IDNAError('should pass a unicode string to the function rather than a byte string.')
    domain = uts46_remap(s, std3_rules, transitional) if uts46 else s

    labels = domain.split('.') if strict else _unicode_dots_re.split(domain)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')
    trailing_dot = labels[-1] == ''
    if trailing_dot:
        del labels[-1]

    encoded_labels = []
    for label in labels:
        encoded = alabel(label)
        if not encoded:
            raise IDNAError('Empty label')
        encoded_labels.append(encoded)
    if trailing_dot:
        encoded_labels.append(b'')

    result = b'.'.join(encoded_labels)
    if not valid_string_length(result, trailing_dot):
        raise IDNAError('Domain too long')
    return result
def decode(s: Union[str, bytes, bytearray], strict: bool = False, uts46: bool = False, std3_rules: bool = False) -> str:
    """Decode a whole domain name to its Unicode form.

    Args:
        s: the domain; ``bytes``/``bytearray`` input must be pure ASCII.
        strict: split only on U+002E FULL STOP instead of every dot accepted
            by ``_unicode_dots_re``.
        uts46: run ``uts46_remap`` (non-transitional) before splitting.
        std3_rules: forwarded to ``uts46_remap``.

    Returns:
        The labels converted by ``ulabel`` joined with ``'.'``; a trailing
        dot in the input is preserved.

    Raises:
        IDNAError: non-ASCII bytes input, an empty domain or label, plus any
            error raised while converting a label.
    """
    if isinstance(s, (bytes, bytearray)):
        try:
            s = s.decode('ascii')
        except UnicodeDecodeError:
            raise IDNAError('Invalid ASCII in A-label')
    if uts46:
        s = uts46_remap(s, std3_rules, False)

    labels = s.split('.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')
    trailing_dot = not labels[-1]
    if trailing_dot:
        del labels[-1]

    unicode_labels = []
    for label in labels:
        converted = ulabel(label)
        if not converted:
            raise IDNAError('Empty label')
        unicode_labels.append(converted)
    if trailing_dot:
        unicode_labels.append('')
    return '.'.join(unicode_labels)