# -*- coding: utf-8 -*-

import os
import sys
import torch
import logging
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml
from pathlib import Path
import torchaudio.transforms as T
from cv_train import ASRCV
import torchaudio
import numpy as np
import kenlm
from pyctcdecode import build_ctcdecoder
import re
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
import torch.nn as nn


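# Load the hyperparameters of the main wav2vec2 + CTC model (hard-coded to
# hparams/train_semi.yaml); the Mixer hyperparameters are parsed from the
# command line further below.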
hparams_file, run_opts, overrides = sb.parse_arguments(["hparams/train_semi.yaml"])

# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)
# Dataset prep (building datasets from the train/valid/test CSV manifests)

def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions."""

    # 1. Define datasets
    data_folder = hparams["data_folder"]

    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
    )

    if hparams["sorting"] == "ascending":
        # we sort training data to speed up training and get better results.
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
        hparams["dataloader_options"]["shuffle"] = False

    elif hparams["sorting"] == "descending":
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            reverse=True,
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
        hparams["dataloader_options"]["shuffle"] = False

    elif hparams["sorting"] == "random":
        pass

    else:
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
    )
    # We also sort the validation data so it is faster to validate
    valid_data = valid_data.filtered_sorted(sort_key="duration")
    test_datasets = {}
    for csv_file in hparams["test_csv"]:
        name = Path(csv_file).stem
        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=csv_file, replacements={"data_root": data_folder}
        )
        test_datasets[name] = test_datasets[name].filtered_sorted(
            sort_key="duration"
        )

    datasets = [train_data, valid_data] + list(test_datasets.values())


    # 2. Define audio pipeline:
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        info = torchaudio.info(wav)
        sig = sb.dataio.dataio.read_audio(wav)
        if len(sig.shape) > 1:
            sig = torch.mean(sig, dim=1)
        resampled = torchaudio.transforms.Resample(
            info.sample_rate, hparams["sample_rate"],
        )(sig)
        return resampled

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
    label_encoder = sb.dataio.encoder.CTCTextEncoder()

    # 3. Define text pipeline:
    @sb.utils.data_pipeline.takes("wrd")
    @sb.utils.data_pipeline.provides(
        "wrd", "char_list", "tokens_list", "tokens"
    )
    def text_pipeline(wrd):
        yield wrd
        char_list = list(wrd)
        yield char_list
        tokens_list = label_encoder.encode_sequence(char_list)
        yield tokens_list
        tokens = torch.LongTensor(tokens_list)
        yield tokens

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
    lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
    special_labels = {
        "blank_label": hparams["blank_index"],
        "unk_label": hparams["unk_index"]
    }
    label_encoder.load_or_create(
        path=lab_enc_file,
        from_didatasets=[train_data],
        output_key="char_list",
        special_labels=special_labels,
        sequence_input=True,
    )

    # 4. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets, ["id", "sig", "wrd", "char_list", "tokens"],
    )
    return train_data, valid_data, test_datasets, label_encoder

class ASR(sb.core.Brain):
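    """Brain for the main CTC model: a wav2vec2 encoder followed by a dense
    encoder and a linear CTC layer, trained with separate optimizers for the
    wav2vec2 encoder (when unfrozen) and for the downstream layers."""
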
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""

        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        # Forward pass
        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        return p_ctc, wav_lens

    def custom_encode(self, wavs, wav_lens):
        wavs = wavs.to(self.device)
        if wav_lens is not None:
            wav_lens = wav_lens.to(self.device)

        feats = self.modules.wav2vec2(wavs, wav_lens)
        x = self.modules.enc(feats)
        logits = self.modules.ctc_lin(x)
        p_ctc = self.hparams.log_softmax(logits)

        return feats,p_ctc



    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC) given predictions and targets."""

        p_ctc, wav_lens = predictions

        ids = batch.id
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)

        if stage != sb.Stage.TRAIN:
            predicted_tokens = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index
            )
            # Decode token terms to words
            if self.hparams.use_language_modelling:
                predicted_words = []
                for logs in p_ctc:
                    text = decoder.decode(logs.detach().cpu().numpy())
                    predicted_words.append(text.split(" "))
            else:
                predicted_words = [
                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                    for utt_seq in predicted_tokens
                ]
            # Convert indices to words
            target_words = [wrd.split(" ") for wrd in batch.wrd]

            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        should_step = self.step % self.grad_accumulation_factor == 0
        # Managing automatic mixed precision
        # TOFIX: CTC fine-tuning currently is unstable
        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
            with torch.cuda.amp.autocast():
                with self.no_sync():
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            with self.no_sync(not should_step):
                self.scaler.scale(
                    loss / self.grad_accumulation_factor
                ).backward()
            if should_step:
                if not self.hparams.wav2vec2.freeze:
                    self.scaler.unscale_(self.wav2vec_optimizer)
                self.scaler.unscale_(self.model_optimizer)
                if self.check_gradients(loss):
                    if not self.hparams.wav2vec2.freeze:
                        if self.optimizer_step >= self.hparams.warmup_steps:
                            self.scaler.step(self.wav2vec_optimizer)
                    self.scaler.step(self.model_optimizer)
                self.scaler.update()
                self.zero_grad()
                self.optimizer_step += 1
        else:
            # This is mandatory because HF models have a weird behavior with DDP
            # on the forward pass
            with self.no_sync():
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)

            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

            with self.no_sync(not should_step):
                (loss / self.grad_accumulation_factor).backward()
            if should_step:
                if self.check_gradients(loss):
                    if not self.hparams.wav2vec2.freeze:
                        if self.optimizer_step >= self.hparams.warmup_steps:
                            self.wav2vec_optimizer.step()
                    self.model_optimizer.step()
                self.zero_grad()
                self.optimizer_step += 1

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        with torch.no_grad():
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
                stage_stats["loss"]
            )
            old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr_model
            )
            if not self.hparams.wav2vec2.freeze:
                sb.nnet.schedulers.update_learning_rate(
                    self.wav2vec_optimizer, new_lr_wav2vec
                )
            self.hparams.train_logger.log_stats(
                stats_meta={
                    "epoch": epoch,
                    "lr_model": old_lr_model,
                    "lr_wav2vec": old_lr_wav2vec,
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)

    def init_optimizers(self):
        "Initializes the wav2vec2 optimizer and model optimizer"

        # If the wav2vec encoder is unfrozen, we create the optimizer
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
                self.modules.wav2vec2.parameters()
            )
            if self.checkpointer is not None:
                self.checkpointer.add_recoverable(
                    "wav2vec_opt", self.wav2vec_optimizer
                )

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    def zero_grad(self, set_to_none=False):
        if not self.hparams.wav2vec2.freeze:
            self.wav2vec_optimizer.zero_grad(set_to_none)
        self.model_optimizer.zero_grad(set_to_none)


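# Load the three pretrained acoustic models that feed the Mixer: a French
# wav2vec2 CTC model fetched from the HuggingFace hub, an English model built
# from en_cv.yaml (ASRCV), and the ASR brain defined above, all in eval mode.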
from speechbrain.pretrained import EncoderASR, EncoderDecoderASR

french_asr_model = EncoderASR.from_hparams(
    source="speechbrain/asr-wav2vec2-commonvoice-fr",
    savedir="pretrained_models/asr-wav2vec2-commonvoice-fr",
).cuda()

cvhparams_file, cvrun_opts, cvoverrides = sb.parse_arguments(["en_cv.yaml"])
with open(cvhparams_file) as cvfin:
    cvhparams = load_hyperpyyaml(cvfin, cvoverrides)
english_asr_model = ASRCV(
        modules=cvhparams["modules"],
        hparams=cvhparams,
        run_opts=cvrun_opts,
        checkpointer=cvhparams["checkpointer"],
    )
english_asr_model.checkpointer.recover_if_possible()
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
asr_brain.checkpointer.recover_if_possible()
asr_brain.modules.eval()
english_asr_model.modules.eval()
french_asr_model.mods.eval()


#UTILS FUNCTIONS
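# Return the dimensions of a nested list by walking down its first elements.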
def get_size_dimensions(arr):
    size_dimensions = []
    while isinstance(arr, list):
        size_dimensions.append(len(arr))
        arr = arr[0]
    return size_dimensions

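# Upsample every sequence in `batch` to length n by repeating its elements,
# then stack the result into a single tensor.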
def scale_array(batch, n):
    scaled_batch = []

    for array in batch:
        if n < len(array):
            raise ValueError("Cannot scale array down")

        repeat = round(n / len(array)) + 1
        scaled_length_array = []

        for i in array:
            for j in range(repeat):
                if len(scaled_length_array) == n:
                    break
                scaled_length_array.append(i)

        scaled_batch.append(scaled_length_array)

    return torch.tensor(scaled_batch)


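# Load a list of wav files, drop the channel dimension, and zero-pad them to
# the length of the longest one.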
def load_paths(wavs_path):
    waveforms = []
    for path in wavs_path:
        waveform, _ = torchaudio.load(path)
        waveforms.append(waveform.squeeze(0))
    # pad all waveforms with zeros to the length of the longest one
    padded_arrays = pad_sequence(waveforms, batch_first=True)
    return padded_arrays



device = 'cuda'
verbose = 0
#FLOW LEVEL FUNCTIONS
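# Concatenate the three models' posteriorgrams and embeddings along the feature
# dimension (dim=2); the time dimensions are assumed to already match.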
def merge_strategy(embeddings1, embeddings2, embeddings3, post1, post2, post3):
    post1 = post1.to(device)
    post2 = post2.to(device)
    post3 = post3.to(device)
    embeddings1 = embeddings1.to(device)
    embeddings2 = embeddings2.to(device)
    embeddings3 = embeddings3.to(device)

    posteriograms_merged = torch.cat((post1, post2, post3), dim=2)
    embeddings_merged = torch.cat((embeddings1, embeddings2, embeddings3), dim=2)

    if verbose != 0:
        print("MERGED POST ", posteriograms_merged.shape)
        print("MERGED emb ", embeddings_merged.shape)

    return torch.cat((posteriograms_merged, embeddings_merged), dim=2).to(device)

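# Gradient-free decoding with a pretrained SpeechBrain model: encode the batch
# and apply the model's decoding function.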
def decode(model, wavs, wav_lens):
    with torch.no_grad():
        wav_lens = wav_lens.to(model.device)
        encoder_out = model.encode_batch(wavs, wav_lens)
        predictions = model.decoding_function(encoder_out, wav_lens)
        return predictions

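# Run the batch through the three acoustic models (TN = asr_brain, FR, EN),
# collect their embeddings and CTC posteriorgrams, and merge everything into a
# single multilingual feature tensor.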
def middle_layer(batch, lens):
    tn_embeddings, tn_posteriogram = asr_brain.custom_encode(batch, None)

    fr_embeddings = french_asr_model.mods.encoder.wav2vec2(batch)
    fr_posteriogram = french_asr_model.encode_batch(batch, lens)
    en_embeddings = english_asr_model.modules.wav2vec2(batch, lens)
    x = english_asr_model.modules.enc(en_embeddings)
    en_posteriogram = english_asr_model.modules.ctc_lin(x)
    # scores, en_posteriogram = english_asr_model.mods.decoder(en_embeddings, lens)
    if verbose != 0:
        print("[EMBEDDINGS] FR:", fr_embeddings.shape, "EN:", en_embeddings.shape, "TN:", tn_embeddings.shape)
        print("[POSTERIOGRAM] FR:", fr_posteriogram.shape, "EN:", en_posteriogram.shape, "TN:", tn_posteriogram.shape)

    multilingual_sample = merge_strategy(
        fr_embeddings, en_embeddings, tn_embeddings,
        fr_posteriogram, en_posteriogram, tn_posteriogram,
    )
    return multilingual_sample

class Mixer(sb.core.Brain):
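    """Brain trained on top of the merged multilingual features produced by
    middle_layer: a dense encoder followed by a linear CTC layer, optimized
    with a single model optimizer while the pretrained acoustic models are
    kept in eval mode."""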

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities."""
        wavs, wav_lens = batch.sig
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)

        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        multilingual_feats = middle_layer(wavs, wav_lens)
        multilingual_feats = multilingual_feats.to(device)
        feats, _ = self.modules.enc(multilingual_feats)
        logits = self.modules.ctc_lin(feats)
        p_ctc = self.hparams.log_softmax(logits)

        if stage != sb.Stage.TRAIN:
            p_tokens = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, blank_id=self.hparams.blank_index
            )
        else:
            p_tokens = None
        return p_ctc, wav_lens, p_tokens

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC) given predictions and targets."""

        p_ctc, wav_lens, predicted_tokens = predictions

        ids = batch.id
        tokens, tokens_lens = batch.tokens

        loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)


        if stage == sb.Stage.VALID:
            predicted_words = [
                "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                for utt_seq in predicted_tokens
            ]
            target_words = [wrd.split(" ") for wrd in batch.wrd]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)
        if stage == sb.Stage.TEST:
            if self.hparams.language_modelling:
                predicted_words = []
                for logs in p_ctc:
                    text = decoder.decode(logs.detach().cpu().numpy())
                    predicted_words.append(text.split(" "))
            else:
                predicted_words = [
                    "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
                    for utt_seq in predicted_tokens
                ]

            target_words = [wrd.split(" ") for wrd in batch.wrd]
            self.wer_metric.append(ids, predicted_words, target_words)
            self.cer_metric.append(ids, predicted_words, target_words)

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        should_step = self.step % self.grad_accumulation_factor == 0
        # Managing automatic mixed precision
        # TOFIX: CTC fine-tuning currently is unstable
        # This is certainly due to CTC being done in fp16 instead of fp32
        if self.auto_mix_prec:
            with torch.cuda.amp.autocast():
                with self.no_sync():
                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            with self.no_sync(not should_step):
                self.scaler.scale(
                    loss / self.grad_accumulation_factor
                ).backward()
            if should_step:
                self.scaler.unscale_(self.model_optimizer)
                if self.check_gradients(loss):
                    self.scaler.step(self.model_optimizer)
                self.scaler.update()
                self.zero_grad()
                self.optimizer_step += 1
        else:
            # This is mandatory because HF models have a weird behavior with DDP
            # on the forward pass
            with self.no_sync():
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)

            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)

            with self.no_sync(not should_step):
                (loss / self.grad_accumulation_factor).backward()
            if should_step:
                if self.check_gradients(loss):
                    self.model_optimizer.step()
                self.zero_grad()
                self.optimizer_step += 1

        self.on_fit_batch_end(batch, outputs, loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        with torch.no_grad():
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        if stage != sb.Stage.TRAIN:
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
                stage_stats["loss"]
            )
            sb.nnet.schedulers.update_learning_rate(
                self.model_optimizer, new_lr_model
            )
            self.hparams.train_logger.log_stats(
                stats_meta={
                    "epoch": epoch,
                    "lr_model": old_lr_model,
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)

    def init_optimizers(self):
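        "Initializes the model optimizer"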

        self.model_optimizer = self.hparams.model_opt_class(
            self.hparams.model.parameters()
        )

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("modelopt", self.model_optimizer)

    def zero_grad(self, set_to_none=False):

        self.model_optimizer.zero_grad(set_to_none)


hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create experiment directory
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)
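# Parse the label encoder file written by dataio_prepare: each line before the
# "===" separator is expected to look like 'a' => 1; return the labels ordered
# by their index.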
def read_labels_file(labels_file):
    with open(labels_file, "r", encoding="utf-8") as lf:
        lines = lf.read().splitlines()
        division = "==="
        numbers = {}
        for line in lines:
            if division in line:
                break
            string, number = line.split("=>")
            number = int(number)
            string = string[1:-2]
            numbers[number] = string
        return [numbers[x] for x in range(len(numbers))]
train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)


labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
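# Remap the encoder labels for pyctcdecode: the first entry (assumed to be the
# CTC blank) becomes the empty string, and the last one is replaced by a
# placeholder token.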
labels = [""] + labels[1:-1] + ["1"] 
if hparams["language_modelling"]:
    decoder = build_ctcdecoder(
        labels,
        kenlm_model_path=hparams["ngram_lm_path"],  # either .arpa or .bin file
        alpha=0.5,  # tuned on a val set
        beta=1,  # tuned on a val set
    )




mixer = Mixer(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
mixer.tokenizer = label_encoder


mixer.fit(
    mixer.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["test_dataloader_options"],
)
print(test_datasets.keys())
for k in test_datasets.keys():  # keys are test_clean, test_other etc
    mixer.hparams.wer_file = os.path.join(
        hparams["output_folder"], "wer_{}.txt".format(k)
    )
    mixer.evaluate(
        test_datasets[k], test_loader_kwargs=hparams["test_dataloader_options"]
        )