# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
#  MIT License (https://opensource.org/licenses/MIT)

"""Tensorflow Layer modules compatible with pytorch."""

import tensorflow as tf


class TFReflectionPad1d(tf.keras.layers.Layer):
    """Reflection padding along the time axis of a 4-D tensor."""

    def __init__(self, padding_size):
        """Store the amount of padding to apply.

        Args:
            padding_size (int): Number of frames added to each end of the time axis.

        """
        super(TFReflectionPad1d, self).__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Pad the input with reflected values on both ends of axis 1.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        pad = self.padding_size
        # Only axis 1 (time) is padded; the remaining three axes are left untouched.
        paddings = [
            [0, 0],
            [pad, pad],
            [0, 0],
            [0, 0],
        ]
        return tf.pad(x, paddings, mode="REFLECT")


class TFConvTranspose1d(tf.keras.layers.Layer):
    """1-D transposed convolution implemented with a Keras ``Conv2DTranspose``."""

    def __init__(self, channels, kernel_size, stride, padding):
        """Create the wrapped ``Conv2DTranspose`` layer.

        Args:
            channels (int): Number of output filters.
            kernel_size (int): Kernel size along the first spatial axis.
            stride (int): Stride along the first spatial axis.
            padding (str): Padding type ("same" or "valid").

        """
        super(TFConvTranspose1d, self).__init__()
        # Kernel size and stride are fixed to 1 on the second spatial axis, so
        # the layer only operates along the first one.
        conv_kwargs = dict(
            filters=channels,
            kernel_size=(kernel_size, 1),
            strides=(stride, 1),
            padding=padding,
        )
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(**conv_kwargs)

    @tf.function
    def call(self, x):
        """Apply the transposed convolution.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)


class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module.

    The main branch is: activation -> reflection pad -> dilated convolution ->
    activation -> 1x1 convolution. Its output is added to a 1x1-convolution
    shortcut of the input.
    """

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size of the dilated convolution. Must be odd.
            channels (int): Number of channels.
            dilation (int): Dilation rate of the dilated convolution.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Name of the activation layer class in
                ``tf.keras.layers``.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            padding (str): Padding type ("same" or "valid"). Currently unused:
                the dilated convolution always uses "valid" padding after an
                explicit reflection pad. Kept for interface compatibility.

        Raises:
            ValueError: If ``kernel_size`` is even, because symmetric padding
                cannot then preserve the time length.

        """
        super(TFResidualStack, self).__init__()
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(
                "kernel_size must be odd, but got {}.".format(kernel_size)
            )
        # A "valid" dilated convolution shortens the time axis by
        # (kernel_size - 1) * dilation, so half of that is padded on each side
        # to keep the main branch the same length as the shortcut. For
        # kernel_size == 3 this equals `dilation`.
        reflection_size = (kernel_size - 1) // 2 * dilation
        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            TFReflectionPad1d(reflection_size),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
        ]
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        _x = x
        for layer in self.block:
            _x = layer(_x)
        # The shortcut is evaluated after the main branch, as before.
        shortcut = self.shortcut(x)
        return shortcut + _x