import sys
sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer')
import numpy as np
import torch
import torch.nn as nn
from vq.residual_vq import ResidualVQ
from vq.module import WNConv1d, DecoderBlock, ResLSTM
from vq.alias_free_torch import *
from vq  import activations
import vq.blocks as blocks
from torch.nn import utils

from vq.bs_roformer5 import TransformerBlock
 
from torchtune.modules import RotaryPositionalEmbeddings

def init_weights(m):
    """Initialise an ``nn.Conv1d`` in place; leave every other module untouched.

    Weights are drawn from a truncated normal distribution with ``std=0.02``
    and the bias, when the layer has one, is set to zero. The function has
    the single-module signature expected by ``nn.Module.apply``.

    Args:
        m: Any module. Only ``nn.Conv1d`` instances are modified.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.trunc_normal_(m.weight, std=0.02)
    # Conv1d(bias=False) stores ``bias`` as None, which constant_ cannot fill.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

class CodecDecoder(nn.Module):
    """Residual vector quantizer plus a convolutional upsampling decoder.

    The module owns two independent parts:

    * ``self.quantizer`` - a ``ResidualVQ`` instance.
    * ``self.model`` - an ``nn.Sequential`` made of an input ``WNConv1d``, an
      optional ``ResLSTM``, one ``DecoderBlock`` per entry of ``up_ratios``
      (halving the channel count each time) and a final
      SnakeBeta -> ``WNConv1d`` -> ``Tanh`` head with one output channel.

    ``forward`` runs exactly one of the two parts, selected by ``vq``.

    Args:
        in_channels: Input channels of the first convolution of ``self.model``.
        upsample_initial_channel: Output channels of that first convolution.
        ngf: Stored as ``self.ngf``; not otherwise used in this class.
        use_rnn: Insert a ``ResLSTM`` after the first convolution.
        rnn_bidirectional: Forwarded to ``ResLSTM``.
        rnn_num_layers: Forwarded to ``ResLSTM``.
        up_ratios: Stride of each ``DecoderBlock``; their product is stored
            as ``self.hop_length``.
        dilations: Forwarded to every ``DecoderBlock``.
        vq_num_quantizers: ``num_quantizers`` of the ``ResidualVQ``.
        vq_dim: ``dim`` of the ``ResidualVQ``.
        vq_commit_weight: ``commitment`` of the ``ResidualVQ``.
        vq_weight_init: ``weight_init`` of the ``ResidualVQ``.
        vq_full_commit_loss: ``full_commit_loss`` of the ``ResidualVQ``.
        codebook_size: ``codebook_size`` of the ``ResidualVQ``.
        codebook_dim: ``codebook_dim`` of the ``ResidualVQ``.
    """

    def __init__(self,
                 in_channels=1024,
                 upsample_initial_channel=1536,
                 ngf=48,
                 use_rnn=True,
                 rnn_bidirectional=False,
                 rnn_num_layers=2,
                 up_ratios=(5, 4, 4, 4, 2),
                 dilations=(1, 3, 9),
                 vq_num_quantizers=1,
                 vq_dim=2048,
                 vq_commit_weight=0.25,
                 vq_weight_init=False,
                 vq_full_commit_loss=False,
                 codebook_size=16384,
                 codebook_dim=32,
                ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.ngf = ngf
        self.up_ratios = up_ratios

        self.quantizer = ResidualVQ(
            num_quantizers=vq_num_quantizers,
            dim=vq_dim,  # double the dim for acoustic and semantic
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
        )

        channels = upsample_initial_channel
        layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

        if use_rnn:
            layers.append(
                ResLSTM(channels,
                        num_layers=rnn_num_layers,
                        bidirectional=rnn_bidirectional)
            )

        # Seed ``output_dim`` so the output head below is still well defined
        # when ``up_ratios`` is empty and the loop body never runs.
        output_dim = channels
        for i, stride in enumerate(up_ratios):
            input_dim = channels // 2 ** i
            output_dim = channels // 2 ** (i + 1)
            layers.append(DecoderBlock(input_dim, output_dim, stride, dilations))

        layers += [
            Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)),
            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        self.reset_parameters()

    def forward(self, x, vq=True):
        """Quantize ``x`` or decode ``x``, depending on ``vq``.

        Args:
            x: Input passed to the quantizer or to the decoder stack.
            vq: When exactly ``True`` only the quantizer runs; any other
                value runs only ``self.model``.

        Returns:
            ``(x, q, commit_loss)`` as returned by the quantizer when
            ``vq is True``, otherwise the output of ``self.model``.
        """
        if vq is True:
            x, q, commit_loss = self.quantizer(x)
            return x, q, commit_loss
        return self.model(x)

    def vq2emb(self, vq):
        """Return ``self.quantizer.vq2emb(vq)``.

        Side effect: the quantizer is switched to eval mode and stays there.
        """
        self.quantizer.eval()
        return self.quantizer.vq2emb(vq)

    def get_emb(self):
        """Return ``self.quantizer.get_emb()``.

        Side effect: the quantizer is switched to eval mode and stays there.
        """
        self.quantizer.eval()
        return self.quantizer.get_emb()

    def inference_vq(self, vq):
        """Decode a single unbatched 2-D input.

        A leading axis of size one is added before ``self.model`` is applied.
        """
        return self.model(vq[None, :, :])

    def inference_0(self, x):
        """Quantize ``x`` and decode the quantized result.

        Returns:
            ``(decoded, None)``.
        """
        # Only the first element is needed, so take it without assuming how
        # many auxiliary values the quantizer returns (``forward`` relies on
        # three; unpacking any other fixed count here would raise).
        quantized = self.quantizer(x)[0]
        return self.model(quantized), None

    def inference(self, x):
        """Decode ``x`` with ``self.model`` and return ``(decoded, None)``."""
        return self.model(x), None

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d / ConvTranspose1d layer."""

        def _apply_weight_norm(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Apply ``init_weights`` to every sub-module (and to ``self``)."""
        self.apply(init_weights)


class CodecDecoder_oobleck_Transformer(nn.Module):
    """Residual vector quantizer plus a transformer / dilated-conv decoder.

    The decode path is: ``depth`` ``TransformerBlock`` layers sharing one
    rotary positional embedding -> ``LayerNorm`` -> swap of the last two axes
    -> ``blocks.DilatedResidualDecoder``. The convolutional decoder is built
    from the factory methods ``pre_conv``, ``post_conv``, ``dilated_unit``
    and ``upsampling_unit`` defined on this class.

    ``forward`` runs either the quantizer or the decode path, selected by
    ``vq``.

    Args:
        ngf: Stored as ``self.capacity`` and passed as ``capacity`` to the
            convolutional decoder.
        up_ratios: ``ratios`` of the convolutional decoder; their product is
            stored as ``self.hop_length``.
        dilations: ``dilations`` of the convolutional decoder.
        vq_num_quantizers: ``num_quantizers`` of the ``ResidualVQ``.
        vq_dim: ``dim`` of the ``ResidualVQ``.
        vq_commit_weight: ``commitment`` of the ``ResidualVQ``.
        vq_weight_init: ``weight_init`` of the ``ResidualVQ``.
        vq_full_commit_loss: ``full_commit_loss`` of the ``ResidualVQ``.
        codebook_size: ``codebook_size`` of the ``ResidualVQ``.
        codebook_dim: ``codebook_dim`` of the ``ResidualVQ``.
        hidden_dim: Width of the transformer blocks, of the final layer norm
            and of the input to ``pre_conv``.
        depth: Number of transformer blocks.
        heads: ``n_heads`` of each transformer block.
        pos_meb_dim: ``dim`` of the rotary positional embedding.
    """

    def __init__(self,
                ngf=32,
                up_ratios=(5, 4, 4, 4, 2),
                dilations=(1, 3, 9),
                vq_num_quantizers=1,
                vq_dim=1024,
                vq_commit_weight=0.25,
                vq_weight_init=False,
                vq_full_commit_loss=False,
                codebook_size=16384,
                codebook_dim=16,
                hidden_dim=1024,
                depth=12,
                heads=16,
                pos_meb_dim=64,
                ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.capacity = ngf
        self.up_ratios = up_ratios
        # Must be set before the conv decoder is built: pre_conv reads it.
        self.hidden_dim = hidden_dim
        self.quantizer = ResidualVQ(
            num_quantizers=vq_num_quantizers,
            dim=vq_dim,  # double the dim for acoustic and semantic
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
        )

        # A single rotary embedding instance is shared by all blocks.
        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
        transformer_blocks = [
            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
            for _ in range(depth)
        ]
        self.transformers = nn.Sequential(*transformer_blocks)

        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)

        self.conv_blocks = blocks.DilatedResidualDecoder(
            capacity=self.capacity,
            dilated_unit=self.dilated_unit,
            upsampling_unit=self.upsampling_unit,
            ratios=up_ratios,  # reverses the encoder's downsampling ratios
            dilations=dilations,
            pre_network_conv=self.pre_conv,
            post_network_conv=self.post_conv,
        )

        self.reset_parameters()

    def _decode(self, x):
        """Run transformers, final layer norm and the convolutional decoder."""
        x = self.transformers(x)
        x = self.final_layer_norm(x)
        # LayerNorm works on the last axis while the Conv1d-based decoder
        # takes channels on axis 1, hence the swap of the last two axes.
        x = x.permute(0, 2, 1)
        return self.conv_blocks(x)

    def forward(self, x, vq=True):
        """Quantize ``x`` or decode ``x``, depending on ``vq``.

        Args:
            x: Input passed to the quantizer or to the decode path.
            vq: When exactly ``True`` only the quantizer runs; any other
                value runs only the decode path.

        Returns:
            ``(x, q, commit_loss)`` as returned by the quantizer when
            ``vq is True``, otherwise the output of the convolutional decoder.
        """
        if vq is True:
            x, q, commit_loss = self.quantizer(x)
            return x, q, commit_loss
        return self._decode(x)

    def vq2emb(self, vq):
        """Return ``self.quantizer.vq2emb(vq)``.

        Side effect: the quantizer is switched to eval mode and stays there.
        """
        self.quantizer.eval()
        return self.quantizer.vq2emb(vq)

    def get_emb(self):
        """Return ``self.quantizer.get_emb()``.

        Side effect: the quantizer is switched to eval mode and stays there.
        """
        self.quantizer.eval()
        return self.quantizer.get_emb()

    def inference_vq(self, vq):
        """Decode a single unbatched 2-D input.

        A leading axis of size one is added and the result goes through the
        same decode path as ``forward(x, vq=False)``. This class has no
        ``self.model`` attribute, so the decode path is used instead.
        """
        return self._decode(vq[None, :, :])

    def inference_0(self, x):
        """Quantize ``x`` and decode the quantized result.

        Returns:
            ``(decoded, None)``.
        """
        # Only the first element is needed, so take it without assuming how
        # many auxiliary values the quantizer returns.
        quantized = self.quantizer(x)[0]
        # NOTE(review): the quantized tensor is fed to the transformer as is;
        # confirm the quantizer's output layout matches what the transformer
        # blocks expect, or transpose here.
        return self._decode(quantized), None

    def inference(self, x):
        """Decode ``x`` and return ``(decoded, None)``."""
        return self._decode(x), None

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d / ConvTranspose1d layer."""

        def _apply_weight_norm(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Apply ``init_weights`` to every sub-module (and to ``self``)."""
        self.apply(init_weights)

    def pre_conv(self, out_channels):
        """Build the 1x1 conv mapping ``hidden_dim`` channels to ``out_channels``."""
        return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1)

    # Post-processing convolution: maps the decoder output to the final
    # number of output channels (one).
    def post_conv(self, in_channels):
        """Build the 1x1 conv mapping ``in_channels`` channels to one channel."""
        return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1)

    def dilated_unit(self, hidden_dim, dilation):
        """Build a kernel-size-3 dilated unit with ReLU and weight norm."""
        return blocks.DilatedConvolutionalUnit(
            hidden_dim=hidden_dim,
            dilation=dilation,
            kernel_size=3,
            activation=nn.ReLU,
            normalization=utils.weight_norm,
        )

    # Upsampling unit factory.
    def upsampling_unit(self, input_dim, output_dim, stride):
        """Build an upsampling unit with ReLU and weight norm."""
        return blocks.UpsamplingUnit(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=stride,
            activation=nn.ReLU,
            normalization=utils.weight_norm,
        )

def main():
    """Smoke-test ``CodecDecoder_oobleck_Transformer`` on random input.

    Builds the model with its default arguments, pushes one random batch
    through the decode path (``vq=False``) without tracking gradients, and
    prints the input and output shapes.
    """
    # Select the device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Build the model.
    model = CodecDecoder_oobleck_Transformer().to(device)
    print("Model initialized.")

    # Test input laid out as batch_size x sequence_length x feature_dim.
    # The feature size is read from the model so the two cannot drift apart.
    batch_size = 2
    feature_dim = model.hidden_dim
    sequence_length = 100  # example length; adjust as needed
    dummy_input = torch.randn(batch_size, sequence_length, feature_dim, device=device)
    print(f"Dummy input shape: {dummy_input.shape}")

    # Put the model in evaluation mode.
    model.eval()

    # No gradients are needed for a forward-only check.
    with torch.no_grad():
        output_no_vq = model(dummy_input, vq=False)
    print(f"Output shape: {output_no_vq.shape}")
 
# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()