# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import annotations
import typing
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import ECB
from cryptography.hazmat.primitives.constant_time import bytes_eq
def _wrap_core(
    wrapping_key: bytes,
    a: bytes,
    r: list[bytes],
) -> bytes:
    """Apply the RFC 3394 section 2.2.1 wrapping rounds (index-based form).

    :param wrapping_key: AES key used as the key-encryption key.
    :param a: 8-byte initial value loaded into the integrity register A.
    :param r: plaintext blocks R[1..n] (8 bytes each, as produced by the
        callers in this module); the list is overwritten in place with the
        final register contents.
    :return: ``A || R[1] || ... || R[n]`` after six rounds of n steps.
    """
    encryptor = Cipher(AES(wrapping_key), ECB()).encryptor()
    # t is the RFC's step counter: it runs 1, 2, ..., 6n across all rounds,
    # which equals n*j + i + 1 for round j and block index i.
    t = 0
    for _ in range(6):
        for idx, block in enumerate(r):
            t += 1
            # A || R[i] is exactly one 16-byte AES block, so ECB returns a
            # full block for every update() and one encryptor serves all steps.
            out = encryptor.update(a + block)
            msb = int.from_bytes(out[:8], byteorder="big")
            a = (msb ^ t).to_bytes(length=8, byteorder="big")
            r[idx] = out[-8:]
    # Nothing should remain buffered because every input was block-sized.
    assert encryptor.finalize() == b""
    return a + b"".join(r)
def aes_key_wrap(
    wrapping_key: bytes,
    key_to_wrap: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Wrap ``key_to_wrap`` under ``wrapping_key`` using RFC 3394.

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param key_to_wrap: key material; at least 16 bytes and a multiple of 8.
    :param backend: unused by this implementation.
    :return: the wrapped key, 8 bytes longer than ``key_to_wrap``.
    :raises ValueError: if either length requirement above is violated.
    """
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")
    plain_len = len(key_to_wrap)
    if plain_len < 16:
        raise ValueError("The key to wrap must be at least 16 bytes")
    if plain_len % 8:
        raise ValueError("The key to wrap must be a multiple of 8 bytes")
    # Split into 64-bit blocks R[1..n] for the wrapping rounds.
    blocks = [key_to_wrap[off : off + 8] for off in range(0, plain_len, 8)]
    # RFC 3394 section 2.2.3.1 default initial value.
    return _wrap_core(wrapping_key, b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6", blocks)
def _unwrap_core(
    wrapping_key: bytes,
    a: bytes,
    r: list[bytes],
) -> tuple[bytes, list[bytes]]:
    """Apply the RFC 3394 section 2.2.2 unwrapping rounds (index-based form).

    :param wrapping_key: AES key used as the key-encryption key.
    :param a: the first 8-byte ciphertext block (initial register A).
    :param r: the remaining ciphertext blocks C[1..n]; overwritten in place.
    :return: ``(A, R)``, the recovered integrity register and the list ``r``
        now holding the candidate plaintext blocks. Callers check ``A``.
    """
    decryptor = Cipher(AES(wrapping_key), ECB()).decryptor()
    n = len(r)
    # Steps run in reverse, so the counter starts at 6n and counts down to 1;
    # for round j and index i it equals n*j + i + 1, as in the RFC.
    t = 6 * n
    for _ in range(6):
        for idx in range(n - 1, -1, -1):
            masked = (int.from_bytes(a, byteorder="big") ^ t).to_bytes(
                length=8, byteorder="big"
            )
            t -= 1
            # (A ^ t) || R[i] is a single 16-byte block, so the decryptor can
            # be reused across every step.
            out = decryptor.update(masked + r[idx])
            a, r[idx] = out[:8], out[-8:]
    assert decryptor.finalize() == b""
    return a, r
def aes_key_wrap_with_padding(
    wrapping_key: bytes,
    key_to_wrap: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Wrap ``key_to_wrap`` under ``wrapping_key`` using RFC 5649 (with padding).

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param key_to_wrap: non-empty key material of any length. Its length is
        stored in a 32-bit field, so ``to_bytes`` raises OverflowError if the
        length does not fit.
    :param backend: unused by this implementation.
    :return: the wrapped key: the zero-padded input plus one 8-byte block.
    :raises ValueError: if ``wrapping_key`` has an invalid AES key length or
        ``key_to_wrap`` is empty.
    """
    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")
    if not key_to_wrap:
        # With no plaintext blocks the wrap rounds never execute, and the
        # "wrapped" result would be the unencrypted AIV itself.
        raise ValueError("The key to wrap must be at least 1 byte")

    # Alternative initial value: constant A65959A6 || 32-bit big-endian
    # message length indicator (MLI) holding the unpadded length.
    aiv = b"\xA6\x59\x59\xA6" + len(key_to_wrap).to_bytes(
        length=4, byteorder="big"
    )
    # Zero-pad the key up to the next multiple of 8 bytes.
    pad = (8 - (len(key_to_wrap) % 8)) % 8
    key_to_wrap = key_to_wrap + b"\x00" * pad
    if len(key_to_wrap) == 8:
        # RFC 5649 - 4.1 - exactly 8 octets after padding: AIV || P is a
        # single 16-byte AES block, encrypted directly.
        encryptor = Cipher(AES(wrapping_key), ECB()).encryptor()
        b = encryptor.update(aiv + key_to_wrap)
        assert encryptor.finalize() == b""
        return b
    else:
        r = [key_to_wrap[i : i + 8] for i in range(0, len(key_to_wrap), 8)]
        return _wrap_core(wrapping_key, aiv, r)
def aes_key_unwrap_with_padding(
    wrapping_key: bytes,
    wrapped_key: bytes,
    backend: typing.Any = None,
) -> bytes:
    if len(wrapped_key) < 16:
        raise InvalidUnwrap("Must be at least 16 bytes")
    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")
    if len(wrapped_key) == 16:
        # RFC 5649 - 4.2 - exactly two 64-bit blocks
        decryptor = Cipher(AES(wrapping_key), ECB()).decryptor()
        out = decryptor.update(wrapped_key)
        assert decryptor.finalize() == b""
        a = out[:8]
        data = out[8:]
        n = 1
    else:
        r = [wrapped_key[i : i + 8] for i in range(0, len(wrapped_key), 8)]
        encrypted_aiv = r.pop(0)
        n = len(r)
        a, r = _unwrap_core(wrapping_key, encrypted_aiv, r)
        data = b"".join(r)
    # 1) Check that MSB(32,A) = A65959A6.
    # 2) Check that 8*(n-1) < LSB(32,A) <= 8*n.  If so, let
    #    MLI = LSB(32,A).
    # 3) Let b = (8*n)-MLI, and then check that the rightmost b octets of
    #    the output data are zero.
    mli = int.from_bytes(a[4:], byteorder="big")
    b = (8 * n) - mli
    if (
        not bytes_eq(a[:4], b"\xa6\x59\x59\xa6")
        or not 8 * (n - 1) < mli <= 8 * n
        or (b != 0 and not bytes_eq(data[-b:], b"\x00" * b))
    ):
        raise InvalidUnwrap()
    if b == 0:
        return data
    else:
        return data[:-b]
def aes_key_unwrap(
    wrapping_key: bytes,
    wrapped_key: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Unwrap a key wrapped with RFC 3394 (default initial value).

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param wrapped_key: ciphertext; at least 24 bytes and a multiple of 8.
    :param backend: unused by this implementation.
    :return: the recovered key, 8 bytes shorter than ``wrapped_key``.
    :raises InvalidUnwrap: on a bad ciphertext length or if the recovered
        integrity register does not match the default initial value.
    :raises ValueError: if ``wrapping_key`` has an invalid AES key length.
    """
    if len(wrapped_key) < 24:
        raise InvalidUnwrap("Must be at least 24 bytes")
    if len(wrapped_key) % 8:
        raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes")
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")
    chunks = [
        wrapped_key[pos : pos + 8] for pos in range(0, len(wrapped_key), 8)
    ]
    # First block is the encrypted register A; the rest are C[1..n].
    register, plain_blocks = _unwrap_core(wrapping_key, chunks[0], chunks[1:])
    # Integrity check against the RFC 3394 default IV, in constant time.
    if not bytes_eq(register, b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"):
        raise InvalidUnwrap()
    return b"".join(plain_blocks)
class InvalidUnwrap(Exception):
    """Raised when wrapped key data is malformed or fails integrity checks."""