"""
File: model_translation.py

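Description:
   Load the models used for text translation (M2M100, MADLAD-400 and
   bilingual Helsinki-NLP Opus-MT models) and split input text into chunks.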

Author: Didier Guillevic
Date: 2024-03-16
"""
import spaces

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig

from model_spacy import nlp_xx as model_spacy

# bitsandbytes 8-bit (LLM.int8) quantization settings, passed to
# from_pretrained() by the MADLAD loader. For the choice of
# llm_int8_threshold, see:
# https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
quantization_config = BitsAndBytesConfig(
    llm_int8_threshold=200.0,
    load_in_8bit=True,
)

# Display label -> language code for the 100 languages supported by the
# facebook/m2m100_418M model (https://huggingface.co/facebook/m2m100_418M),
# plus an 'AUTOMATIC' entry meaning the source language should be detected
# with a language detector instead of chosen by the user.
language_codes = {
    'AUTOMATIC': 'auto',
    'Afrikaans (af)': 'af',
    'Albanian (sq)': 'sq',
    'Amharic (am)': 'am',
    'Arabic (ar)': 'ar',
    'Armenian (hy)': 'hy',
    'Asturian (ast)': 'ast',
    'Azerbaijani (az)': 'az',
    'Bashkir (ba)': 'ba',
    'Belarusian (be)': 'be',
    'Bengali (bn)': 'bn',
    'Bosnian (bs)': 'bs',
    'Breton (br)': 'br',
    'Bulgarian (bg)': 'bg',
    'Burmese (my)': 'my',
    'Catalan; Valencian (ca)': 'ca',
    'Cebuano (ceb)': 'ceb',
    'Central Khmer (km)': 'km',
    'Chinese (zh)': 'zh',
    'Croatian (hr)': 'hr',
    'Czech (cs)': 'cs',
    'Danish (da)': 'da',
    'Dutch; Flemish (nl)': 'nl',
    'English (en)': 'en',
    'Estonian (et)': 'et',
    'Finnish (fi)': 'fi',
    'French (fr)': 'fr',
    'Fulah (ff)': 'ff',
    'Gaelic; Scottish Gaelic (gd)': 'gd',
    'Galician (gl)': 'gl',
    'Ganda (lg)': 'lg',
    'Georgian (ka)': 'ka',
    'German (de)': 'de',
    'Greek (el)': 'el',
    'Gujarati (gu)': 'gu',
    'Haitian; Haitian Creole (ht)': 'ht',
    'Hausa (ha)': 'ha',
    'Hebrew (he)': 'he',
    'Hindi (hi)': 'hi',
    'Hungarian (hu)': 'hu',
    'Icelandic (is)': 'is',
    'Igbo (ig)': 'ig',
    'Iloko (ilo)': 'ilo',
    'Indonesian (id)': 'id',
    'Irish (ga)': 'ga',
    'Italian (it)': 'it',
    'Japanese (ja)': 'ja',
    'Javanese (jv)': 'jv',
    'Kannada (kn)': 'kn',
    'Kazakh (kk)': 'kk',
    'Korean (ko)': 'ko',
    'Lao (lo)': 'lo',
    'Latvian (lv)': 'lv',
    'Lingala (ln)': 'ln',
    'Lithuanian (lt)': 'lt',
    'Luxembourgish; Letzeburgesch (lb)': 'lb',
    'Macedonian (mk)': 'mk',
    'Malagasy (mg)': 'mg',
    'Malay (ms)': 'ms',
    'Malayalam (ml)': 'ml',
    'Marathi (mr)': 'mr',
    'Mongolian (mn)': 'mn',
    'Nepali (ne)': 'ne',
    'Northern Sotho (ns)': 'ns',
    'Norwegian (no)': 'no',
    'Occitan (post 1500) (oc)': 'oc',
    'Oriya (or)': 'or',
    'Panjabi; Punjabi (pa)': 'pa',
    'Persian (fa)': 'fa',
    'Polish (pl)': 'pl',
    'Portuguese (pt)': 'pt',
    'Pushto; Pashto (ps)': 'ps',
    'Romanian; Moldavian; Moldovan (ro)': 'ro',
    'Russian (ru)': 'ru',
    'Serbian (sr)': 'sr',
    'Sindhi (sd)': 'sd',
    'Sinhala; Sinhalese (si)': 'si',
    'Slovak (sk)': 'sk',
    'Slovenian (sl)': 'sl',
    'Somali (so)': 'so',
    'Spanish (es)': 'es',
    'Sundanese (su)': 'su',
    'Swahili (sw)': 'sw',
    'Swati (ss)': 'ss',
    'Swedish (sv)': 'sv',
    'Tagalog (tl)': 'tl',
    'Tamil (ta)': 'ta',
    'Thai (th)': 'th',
    'Tswana (tn)': 'tn',
    'Turkish (tr)': 'tr',
    'Ukrainian (uk)': 'uk',
    'Urdu (ur)': 'ur',
    'Uzbek (uz)': 'uz',
    'Vietnamese (vi)': 'vi',
    'Welsh (cy)': 'cy',
    'Western Frisian (fy)': 'fy',
    'Wolof (wo)': 'wo',
    'Xhosa (xh)': 'xh',
    'Yiddish (yi)': 'yi',
    'Yoruba (yo)': 'yo',
    'Zulu (zu)': 'zu',
}

# Target languages offered for translation output (display label -> code);
# both entries also appear in language_codes.
tgt_language_codes = dict(
    [
        ('English (en)', 'en'),
        ('French (fr)', 'fr'),
    ]
)


def build_text_chunks(
        text: str,
        sents_per_chunk: int=5,
        words_per_chunk: int=200) -> list[str]:
    """Split a text into chunks of at most sents_per_chunk sentences and
    (normally) at most words_per_chunk words.

    The text is segmented into sentences with the spaCy pipeline imported as
    ``model_spacy``; blank sentences are dropped. Sentences are then packed
    greedily, in order: a new chunk is started whenever adding the next
    sentence would exceed either limit. Within a chunk, sentences are joined
    with a newline (one sentence per line). Words are counted with
    ``str.split()``.

    A single sentence longer than words_per_chunk is never split; it becomes
    a chunk on its own, so such a chunk can exceed words_per_chunk.

    Args:
        text: the text to split.
        sents_per_chunk: maximum number of sentences per chunk.
        words_per_chunk: maximum number of whitespace-separated words per chunk.

    Returns:
        The non-empty chunks, in text order. Empty if text contains no
        non-blank sentence.
    """
    # Split text into sentences, dropping blank ones.
    sentences = [
        sent.text.strip() for sent in model_spacy(text).sents if sent.text.strip()
    ]
    logger.info("TEXT: %s, NB_SENTS: %d", text[:25], len(sentences))

    chunks = []
    chunk_sents = []    # sentences of the chunk being built
    chunk_nb_words = 0  # word count of the chunk being built

    for sent in sentences:
        sent_nb_words = len(sent.split())

        # Flush the current chunk if adding this sentence would overflow it.
        # The `chunk_sents` guard prevents emitting an empty chunk when the
        # very first sentence alone exceeds words_per_chunk (or when
        # sents_per_chunk < 1); an empty chunk would otherwise be translated.
        if chunk_sents and (
                chunk_nb_words + sent_nb_words > words_per_chunk
                or len(chunk_sents) + 1 > sents_per_chunk):
            chunks.append('\n'.join(chunk_sents))
            chunk_sents = []
            chunk_nb_words = 0

        chunk_sents.append(sent)
        chunk_nb_words += sent_nb_words

    # Append the last, partially filled chunk.
    if chunk_sents:
        chunks.append('\n'.join(chunk_sents))

    return chunks


class Singleton(type):
    """Metaclass that gives each class using it at most one instance.

    The first call to a class constructs the instance with the given
    arguments; every later call returns that same instance and ignores its
    arguments. Instances are stored per class in a registry shared by all
    classes using this metaclass.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance


class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model (418M parameters).

    Model: https://huggingface.co/facebook/m2m100_418M

    The tokenizer and model are loaded in ``__init__``. Because the class uses
    the Singleton metaclass, later instantiations return the same object.
    """
    def __init__(self):
        self._model_name = "facebook/m2m100_418M"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            # 8-bit loading is currently disabled for this model:
            # quantization_config=quantization_config,
        )
        self._model = torch.compile(self._model)

    @spaces.GPU
    def translate(
            self,
            text: str,
            src_lang: str,
            tgt_lang: str,
            chunk_text: bool=True,
            sents_per_chunk: int=5,
            words_per_chunk: int=200
        ) -> str:
        """Translate the given text from src_lang to tgt_lang.

        When chunk_text is True, the text is first split with
        build_text_chunks() and each chunk is translated independently. This
        keeps each model input short.

        Args:
            text: text to translate.
            src_lang: source language code understood by the M2M100 tokenizer
                (e.g. 'fr'). NOTE(review): the 'auto' code from
                language_codes is presumably resolved by the caller before
                reaching here. Confirm.
            tgt_lang: target language code (e.g. 'en').
            chunk_text: whether to split text into chunks before translating.
            sents_per_chunk: maximum sentences per chunk.
            words_per_chunk: maximum words per chunk.

        Returns:
            The translated chunks joined with newlines.
        """
        chunks = [text]
        if chunk_text:
            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)

        # The tokenizer encodes inputs according to its src_lang attribute.
        # Note that this mutates the shared (singleton) tokenizer.
        self._tokenizer.src_lang = src_lang
        # Forcing the first generated token to the target-language id selects
        # the output language. It is the same for every chunk, so look it up
        # once.
        tgt_lang_id = self._tokenizer.get_lang_id(tgt_lang)

        translated_chunks = []
        for chunk in chunks:
            input_ids = self._tokenizer(
                chunk,
                return_tensors="pt").input_ids.to(self._model.device)
            outputs = self._model.generate(
                input_ids=input_ids,
                forced_bos_token_id=tgt_lang_id)
            translated_chunk = self._tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True)[0]
            translated_chunks.append(translated_chunk)

        return '\n'.join(translated_chunks)

    @property
    def model_name(self):
        """Hugging Face model id that was loaded."""
        return self._model_name

    @property
    def tokenizer(self):
        """The M2M100 tokenizer."""
        return self._tokenizer

    @property
    def model(self):
        """The (torch.compile-wrapped) M2M100 model."""
        return self._model

    @property
    def device(self):
        """Device the model's inputs are moved to."""
        return self._model.device


class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD-400 translation model (3B).

    Model: https://huggingface.co/google/madlad400-3b-mt

    The model is loaded with the module-level 8-bit ``quantization_config``.
    Because the class uses the Singleton metaclass, later instantiations
    return the same object.
    """
    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._input_max_length = 512 # config.json n_positions
        self._output_max_length = 512 # config.json n_positions
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config
        )
        self._model = torch.compile(self._model)

    @spaces.GPU
    def translate(
            self,
            text: str,
            tgt_lang: str,
            chunk_text: bool=True,
            sents_per_chunk: int=5,
            # NOTE(review): 5 is far below the 200 used by
            # ModelM2M100.translate and max_words_per_chunk. With it, nearly
            # every sentence becomes its own chunk. Confirm whether 200 was
            # intended.
            words_per_chunk: int=5
        ) -> str:
        """Translate the given text into tgt_lang.

        MADLAD-400 takes no source-language argument. The target language is
        selected by prefixing each input with a ``<2xx>`` token (e.g.
        ``<2fr>``). When chunk_text is True, the text is first split with
        build_text_chunks() and each chunk is translated independently. Each
        input is truncated to 512 tokens, and each output is limited to 512
        tokens.

        Args:
            text: text to translate.
            tgt_lang: target language code used in the ``<2xx>`` prefix.
            chunk_text: whether to split text into chunks before translating.
            sents_per_chunk: maximum sentences per chunk.
            words_per_chunk: maximum words per chunk.

        Returns:
            The translated chunks joined with newlines.
        """
        chunks = [text]
        if chunk_text:
            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)

        translated_chunks = []
        for chunk in chunks:
            input_text = f"<2{tgt_lang}> {chunk}"
            logger.info(" Translating: %s", input_text[:50])
            input_ids = self._tokenizer(
                input_text,
                return_tensors="pt",
                max_length=self._input_max_length,
                truncation=True,
                padding="longest").input_ids.to(self._model.device)
            outputs = self._model.generate(
                input_ids=input_ids,
                max_length=self._output_max_length)
            translated_chunk = self._tokenizer.decode(
                outputs[0],
                skip_special_tokens=True)
            translated_chunks.append(translated_chunk)

        return '\n'.join(translated_chunks)

    @property
    def model_name(self):
        """Hugging Face model id that was loaded."""
        return self._model_name

    @property
    def tokenizer(self):
        """The MADLAD tokenizer."""
        return self._tokenizer

    @property
    def model(self):
        """The (torch.compile-wrapped) MADLAD model."""
        return self._model

    @property
    def device(self):
        """Device the model's inputs are moved to."""
        return self._model.device


# Bilingual Helsinki-NLP Opus-MT models, keyed by source language code.
# Most translate into English; the English model translates into French.
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    # NOTE(review): the "itc" suffix suggests this model targets Italic
    # languages rather than English and may need a target-language token.
    # Confirm.
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}

# Source languages with a bilingual model. Derived from model_names so that
# every listed language can actually be loaded.
src_langs = set(model_names)

# Cache of bilingual models already loaded: source language -> (tokenizer, model).
tokenizer_model_registry = dict()

# Device the bilingual models are moved to after loading.
device = "cpu"

def get_tokenizer_model_for_src_lang(
        src_lang: str) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for a given source language.

    The model id is looked up in ``model_names`` by lower-cased language code.
    On first use, the tokenizer and model are loaded, the model is converted
    to float16 (unless already float16) and moved to ``device``, and the pair
    is cached in ``tokenizer_model_registry``. Later calls return the cached
    pair.

    Args:
        src_lang: source language code, e.g. 'fr' (case-insensitive).

    Returns:
        A (tokenizer, model) tuple.

    Raises:
        ValueError: if no bilingual model is configured for src_lang.
    """
    src_lang = src_lang.lower()

    # Already loaded? (Cached values are always tuples, never None.)
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    model_name = model_names.get(src_lang)
    if not model_name:
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers keep working.
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # NOTE(review): float16 on CPU is slow, and depending on the installed
    # PyTorch version some ops may not support it. Confirm this is intended
    # while `device` is 'cpu'.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)
    tokenizer_model_registry[src_lang] = (tokenizer, model)

    return tokenizer, model

# Upper bound on the number of words in one input chunk. Model inputs are
# usually limited to 512 tokens (max position encodings and max length).
# Because a word can map to several tokens, the word budget is kept well
# below that limit.
max_words_per_chunk: int = 200