import argparse
import os

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from tqdm import trange
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)


class Embeder:
    """Build L2-normalised audio and text embeddings for the clips listed in a CSV.

    The OmegaConf config must provide ``path_csv``, ``path_data``,
    ``sample_rate``, ``path_audio_embedding`` and ``path_text_embedding``.
    Audio embeddings come from a wav2vec2 x-vector model. Text embeddings come
    from a Japanese BERT applied to each row's file name.
    """

    AUDIO_MODEL_NAME = "anton-l/wav2vec2-base-superb-sv"
    TEXT_MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"

    def __init__(self, config):
        """Load the config, the CSV and the pretrained models.

        Args:
            config: Path to a YAML config file, or an already-loaded
                OmegaConf config object.
        """
        self.config = (
            config if OmegaConf.is_config(config) else OmegaConf.load(config)
        )
        # Read from the *loaded* config: ``config`` itself may be a plain path string.
        self.df = pd.read_csv(self.config.path_csv)
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            self.AUDIO_MODEL_NAME
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(self.AUDIO_MODEL_NAME)
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            self.TEXT_MODEL_NAME
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            self.TEXT_MODEL_NAME,
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()

    def run(self):
        """Create and save the audio embeddings, then the text embeddings.

        Audio runs first because it may drop rows from the CSV. The text
        embeddings are then computed only for the remaining rows.
        """
        self._create_audio_embed()
        self._create_text_embed()

    def _create_audio_embed(self):
        """Embed every audio file listed in the CSV and save the result.

        Each row's ``filename`` has ``.mp3`` replaced by ``.wav`` and the prefix
        ``new_`` added. The file is then loaded from ``config.path_data``.
        Rows whose model call fails are skipped and dropped from the CSV, which
        is rewritten in place. The saved matrix therefore has exactly one row
        per remaining CSV row.

        Raises:
            RuntimeError: If no audio file could be embedded.
        """
        embeds = []
        failed = []
        for i in trange(len(self.df)):
            filename = self.df.iloc[i]["filename"]
            song = AudioSegment.from_wav(
                os.path.join(
                    self.config.path_data,
                    "new_" + filename.replace(".mp3", ".wav"),
                )
            )
            samples = np.array(song.get_array_of_samples(), dtype="float")
            inputs = self.audio_feature_extractor(
                [samples],
                sampling_rate=self.config.sample_rate,
                return_tensors="pt",
                padding=True,
            )
            try:
                with torch.no_grad():
                    embeds.append(self.audio_model(**inputs).embeddings)
            except Exception as err:  # deliberate best-effort: skip the clip, keep going
                print(f"Skipping {filename}: audio embedding failed ({err})")
                failed.append(i)

        if not embeds:
            raise RuntimeError("No audio embeddings could be computed.")
        # Concatenate once instead of growing a tensor inside the loop.
        audio_embed = torch.nn.functional.normalize(torch.cat(embeds), dim=-1).cpu()
        # ``embeds`` holds only the successful rows, so it already lines up with
        # the CSV once the failed rows are dropped. No further filtering is needed.
        torch.save(audio_embed, self.config.path_audio_embedding)
        # ``failed`` holds positions; map them to index labels before dropping.
        self.df = self.df.drop(index=self.df.index[failed])
        self.df.to_csv(self.config.path_csv, index=False)

    def _create_text_embed(self):
        """Embed each row's file name (without ``.mp3``) with BERT and save it.

        The per-row embedding is the token-wise mean of the second-to-last
        hidden layer. The stacked matrix is L2-normalised before saving.

        Raises:
            RuntimeError: If the CSV has no rows.
        """
        embeds = []
        for i in range(len(self.df)):
            sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
            tokens = self.text_tokenizer.tokenize(sentence)
            token_ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
            with torch.no_grad():
                outputs = self.text_model(torch.tensor([token_ids]))
            # hidden_states[-2]: second-to-last layer; [0]: the single batch item.
            embeds.append(outputs.hidden_states[-2][0].mean(dim=0, keepdim=True))
        if not embeds:
            raise RuntimeError("No rows in the CSV to build text embeddings for.")
        text_embed = torch.nn.functional.normalize(torch.cat(embeds), dim=-1).cpu()
        torch.save(text_embed, self.config.path_text_embedding)

    def clean_and_save_data(self, audio_embed, idx):
        """Remove the rows at positions ``idx`` from ``audio_embed`` and save it.

        Args:
            audio_embed: 2-D tensor with one embedding row per original CSV row.
            idx: 0-based row positions to remove.
        """
        drop = set(idx)
        keep = torch.tensor(
            [i for i in range(len(audio_embed)) if i not in drop], dtype=torch.long
        )
        torch.save(audio_embed[keep], self.config.path_audio_embedding)


def argparser():
    """Read the command line and return the parsed namespace.

    Only ``-c/--config`` is recognised: the config file path, defaulting to
    ``config.yaml``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c", "--config", type=str, default="config.yaml",
        help="File path for config file.",
    )
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: build all embeddings from the config given on the CLI.
    config_path = argparser().config
    Embeder(config_path).run()