# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from libc.stdint cimport int32_t, uint16_t

include 'cython_marshal.pyx'
from cassandra.buffer cimport Buffer, to_bytes, slice_buffer
from cassandra.cython_utils cimport datetime_from_timestamp

from cython.view cimport array as cython_array
from cassandra.tuple cimport tuple_new, tuple_set

import socket
from decimal import Decimal
from uuid import UUID

from cassandra import cqltypes
from cassandra import util

cdef class Deserializer:
    """Cython-based deserializer class for a cqltype"""

    def __init__(self, cqltype):
        self.cqltype = cqltype
        self.empty_binary_ok = cqltype.empty_binary_ok

    cdef deserialize(self, Buffer *buf, int protocol_version):
        raise NotImplementedError


cdef class DesBytesType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return b""
        return to_bytes(buf)

# this is to facilitate cqlsh integration, which requires bytearrays for BytesType
# It is switched in by simply overwriting DesBytesType:
# deserializers.DesBytesType = deserializers.DesBytesTypeByteArray
cdef class DesBytesTypeByteArray(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return bytearray()
        return bytearray(buf.ptr[:buf.size])

# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html
cdef class DesDecimalType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Buffer varint_buf
        slice_buffer(buf, &varint_buf, 4, buf.size - 4)

        cdef int32_t scale = unpack_num[int32_t](buf)
        unscaled = varint_unpack(&varint_buf)

        return Decimal('%de%d' % (unscaled, -scale))


cdef class DesUUIDType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


cdef class DesBooleanType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if unpack_num[int8_t](buf):
            return True
        return False


cdef class DesByteType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int8_t](buf)


cdef class DesAsciiType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return ""
        return to_bytes(buf).decode('ascii')


cdef class DesFloatType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[float](buf)


cdef class DesDoubleType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[double](buf)


cdef class DesLongType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int64_t](buf)


cdef class DesInt32Type(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int32_t](buf)


cdef class DesIntegerType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return varint_unpack(buf)


cdef class DesInetAddressType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes byts = to_bytes(buf)

        # TODO: optimize inet_ntop, inet_ntoa
        if buf.size == 16:
            return util.inet_ntop(socket.AF_INET6, byts)
        else:
            # util.inet_pton could also handle, but this is faster
            # since we've already determined the AF
            return socket.inet_ntoa(byts)


cdef class DesCounterColumnType(DesLongType):
    pass


cdef class DesDateType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef double timestamp = unpack_num[int64_t](buf) / 1000.0
        return datetime_from_timestamp(timestamp)


cdef class TimestampType(DesDateType):
    pass


cdef class TimeUUIDType(DesDateType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


# Values of the 'date'` type are encoded as 32-bit unsigned integers
# representing a number of days with epoch (January 1st, 1970) at the center of the
# range (2^31).
EPOCH_OFFSET_DAYS = 2 ** 31

cdef class DesSimpleDateType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        days = unpack_num[uint32_t](buf) - EPOCH_OFFSET_DAYS
        return util.Date(days)


cdef class DesShortType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int16_t](buf)


cdef class DesTimeType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.Time(unpack_num[int64_t](buf))


cdef class DesUTF8Type(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return ""
        cdef val = to_bytes(buf)
        return val.decode('utf8')


cdef class DesVarcharType(DesUTF8Type):
    pass


cdef class _DesParameterizedType(Deserializer):

    cdef object subtypes
    cdef Deserializer[::1] deserializers
    cdef Py_ssize_t subtypes_len

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.subtypes = cqltype.subtypes
        self.deserializers = make_deserializers(cqltype.subtypes)
        self.subtypes_len = len(self.subtypes)


cdef class _DesSingleParamType(_DesParameterizedType):
    cdef Deserializer deserializer

    def __init__(self, cqltype):
        assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes
        super().__init__(cqltype)
        self.deserializer = self.deserializers[0]


#--------------------------------------------------------------------------
# List and set deserialization

cdef class DesListType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef uint16_t v2_and_below = 2
        cdef int32_t v3_and_above = 3

        if protocol_version >= 3:
            result = _deserialize_list_or_set[int32_t](
                v3_and_above, buf, protocol_version, self.deserializer)
        else:
            result = _deserialize_list_or_set[uint16_t](
                v2_and_below, buf, protocol_version, self.deserializer)

        return result

cdef class DesSetType(DesListType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.sortedset(DesListType.deserialize(self, buf, protocol_version))


ctypedef fused itemlen_t:
    uint16_t # protocol <= v2
    int32_t  # protocol >= v3

cdef list _deserialize_list_or_set(itemlen_t dummy_version,
                                   Buffer *buf, int protocol_version,
                                   Deserializer deserializer):
    """
    Deserialize a list or set.

    The 'dummy' parameter is needed to make fused types work, so that
    we can specialize on the protocol version.
    """
    cdef Buffer itemlen_buf
    cdef Buffer elem_buf

    cdef itemlen_t numelements
    cdef int offset
    cdef list result = []

    _unpack_len[itemlen_t](buf, 0, &numelements)
    offset = sizeof(itemlen_t)
    protocol_version = max(3, protocol_version)
    for _ in range(numelements):
        subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version)
        result.append(from_binary(deserializer, &elem_buf, protocol_version))

    return result


cdef inline int subelem(
        Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1:
    """
    Read the next element from the buffer: first read the size (in bytes) of the
    element, then fill elem_buf with a newly sliced buffer of this size (and the
    right offset).
    """
    cdef itemlen_t elemlen

    _unpack_len[itemlen_t](buf, offset[0], &elemlen)
    offset[0] += sizeof(itemlen_t)
    slice_buffer(buf, elem_buf, offset[0], elemlen)
    offset[0] += elemlen
    return 0


cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1:
    cdef Buffer itemlen_buf
    slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t))

    if itemlen_t is uint16_t:
        output[0] = unpack_num[uint16_t](&itemlen_buf)
    else:
        output[0] = unpack_num[int32_t](&itemlen_buf)

    return 0

#--------------------------------------------------------------------------
# Map deserialization

cdef class DesMapType(_DesParameterizedType):

    cdef Deserializer key_deserializer, val_deserializer

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.key_deserializer = self.deserializers[0]
        self.val_deserializer = self.deserializers[1]

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef uint16_t v2_and_below = 0
        cdef int32_t v3_and_above = 0
        key_type, val_type = self.cqltype.subtypes

        if protocol_version >= 3:
            result = _deserialize_map[int32_t](
                v3_and_above, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)
        else:
            result = _deserialize_map[uint16_t](
                v2_and_below, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)

        return result


cdef _deserialize_map(itemlen_t dummy_version,
                      Buffer *buf, int protocol_version,
                      Deserializer key_deserializer, Deserializer val_deserializer,
                      key_type, val_type):
    cdef Buffer key_buf, val_buf
    cdef Buffer itemlen_buf

    cdef itemlen_t numelements
    cdef int offset
    cdef list result = []

    _unpack_len[itemlen_t](buf, 0, &numelements)
    offset = sizeof(itemlen_t)
    themap = util.OrderedMapSerializedKey(key_type, protocol_version)
    protocol_version = max(3, protocol_version)
    for _ in range(numelements):
        subelem[itemlen_t](buf, &key_buf, &offset, dummy_version)
        subelem[itemlen_t](buf, &val_buf, &offset, numelements)
        key = from_binary(key_deserializer, &key_buf, protocol_version)
        val = from_binary(val_deserializer, &val_buf, protocol_version)
        themap._insert_unchecked(key, to_bytes(&key_buf), val)

    return themap

#--------------------------------------------------------------------------

cdef class DesTupleType(_DesParameterizedType):

    # TODO: Use TupleRowParser to parse these tuples

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, p
        cdef int32_t itemlen
        cdef tuple res = tuple_new(self.subtypes_len)
        cdef Buffer item_buf
        cdef Buffer itemlen_buf
        cdef Deserializer deserializer

        # collections inside UDTs are always encoded with at least the
        # version 3 format
        protocol_version = max(3, protocol_version)

        p = 0
        values = []
        for i in range(self.subtypes_len):
            item = None
            if p < buf.size:
                slice_buffer(buf, &itemlen_buf, p, 4)
                itemlen = unpack_num[int32_t](&itemlen_buf)
                p += 4
                if itemlen >= 0:
                    slice_buffer(buf, &item_buf, p, itemlen)
                    p += itemlen

                    deserializer = self.deserializers[i]
                    item = from_binary(deserializer, &item_buf, protocol_version)

            tuple_set(res, i, item)

        return res


cdef class DesUserType(DesTupleType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        typ = self.cqltype
        values = DesTupleType.deserialize(self, buf, protocol_version)
        if typ.mapped_class:
            return typ.mapped_class(**dict(zip(typ.fieldnames, values)))
        elif typ.tuple_type:
            return typ.tuple_type(*values)
        else:
            return tuple(values)


cdef class DesCompositeType(_DesParameterizedType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, idx, start
        cdef Buffer elem_buf
        cdef int16_t element_length
        cdef Deserializer deserializer
        cdef tuple res = tuple_new(self.subtypes_len)

        idx = 0
        for i in range(self.subtypes_len):
            if not buf.size:
                # CompositeType can have missing elements at the end

                # Fill the tuple with None values and slice it
                #
                # (I'm not sure a tuple needs to be fully initialized before
                #  it can be destroyed, so play it safe)
                for j in range(i, self.subtypes_len):
                    tuple_set(res, j, None)
                res = res[:i]
                break

            element_length = unpack_num[uint16_t](buf)
            slice_buffer(buf, &elem_buf, 2, element_length)

            deserializer = self.deserializers[i]
            item = from_binary(deserializer, &elem_buf, protocol_version)
            tuple_set(res, i, item)

            # skip element length, element, and the EOC (one byte)
            start = 2 + element_length + 1
            slice_buffer(buf, buf, start, buf.size - start)

        return res


DesDynamicCompositeType = DesCompositeType


cdef class DesReversedType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return from_binary(self.deserializer, buf, protocol_version)


cdef class DesFrozenType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return from_binary(self.deserializer, buf, protocol_version)

#--------------------------------------------------------------------------

cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size):
    """
    Decide whether to return None or EMPTY when a buffer size is
    zero or negative. This is used by from_binary in deserializers.pxd.
    """
    if buf_size < 0:
        return None
    elif deserializer.cqltype.support_empty_values:
        return cqltypes.EMPTY
    else:
        return None

#--------------------------------------------------------------------------
# Generic deserialization

cdef class GenericDeserializer(Deserializer):
    """
    Wrap a generic datatype for deserialization
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return self.cqltype.deserialize(to_bytes(buf), protocol_version)

    def __repr__(self):
        return "GenericDeserializer(%s)" % (self.cqltype,)

#--------------------------------------------------------------------------
# Helper utilities

def make_deserializers(cqltypes):
    """Create an array of Deserializers for each given cqltype in cqltypes"""
    cdef Deserializer[::1] deserializers
    return obj_array([find_deserializer(ct) for ct in cqltypes])


cdef dict classes = globals()

cpdef Deserializer find_deserializer(cqltype):
    """Find a deserializer for a cqltype"""
    name = 'Des' + cqltype.__name__

    if name in globals():
        cls = classes[name]
    elif issubclass(cqltype, cqltypes.ListType):
        cls = DesListType
    elif issubclass(cqltype, cqltypes.SetType):
        cls = DesSetType
    elif issubclass(cqltype, cqltypes.MapType):
        cls = DesMapType
    elif issubclass(cqltype, cqltypes.UserType):
        # UserType is a subclass of TupleType, so should precede it
        cls = DesUserType
    elif issubclass(cqltype, cqltypes.TupleType):
        cls = DesTupleType
    elif issubclass(cqltype, cqltypes.DynamicCompositeType):
        # DynamicCompositeType is a subclass of CompositeType, so should precede it
        cls = DesDynamicCompositeType
    elif issubclass(cqltype, cqltypes.CompositeType):
        cls = DesCompositeType
    elif issubclass(cqltype, cqltypes.ReversedType):
        cls = DesReversedType
    elif issubclass(cqltype, cqltypes.FrozenType):
        cls = DesFrozenType
    else:
        cls = GenericDeserializer

    return cls(cqltype)


def obj_array(list objs):
    """Create a (Cython) array of objects given a list of objects"""
    cdef object[:] arr
    cdef Py_ssize_t i
    arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
    # arr[:] = objs # This does not work (segmentation faults)
    for i, obj in enumerate(objs):
        arr[i] = obj
    return arr