import gradio as gr
import torch
from transformers import pipeline
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
import fasttext
from huggingface_hub import hf_hub_download


# Fetch the GlotLID fastText language-identification model file from the
# Hugging Face Hub and load it once at import time; lang_ident() uses it.
model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model.bin")
identification_model = fasttext.load_model(model_path)
def lang_ident(text, k=1):
    """Identify the language of *text* with the GlotLID fastText model.

    Args:
        text: Text whose language should be identified. Line breaks are
            collapsed into spaces, because fastText's ``predict`` only
            accepts a single line.
        k: Number of top predictions to return (fastText's default is 1).

    Returns:
        The raw ``predict`` result: a ``(labels, probabilities)`` pair.
    """
    # fastText raises ValueError("predict processes one line at a time ...")
    # on any input containing "\n", so flatten multi-line text first.
    single_line = " ".join(text.splitlines())
    return identification_model.predict(single_line, k=k)

# Hub id of the M2M100 (1.2B) translation checkpoint, and the local directory
# that from_pretrained() uses as its download cache.
pretrained_model: str = "facebook/m2m100_1.2B"
cache_dir: str = "models/"

# Tokenizer and seq2seq model used by translation_text(); loaded once at import time.
tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
translation_model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir)

# Whisper speech-recognition pipeline used by audio_a_text(). Whisper processes
# at most 30 s of audio per window. chunk_length_s=30 makes the pipeline split
# longer recordings into 30 s chunks. Without it, long inputs are truncated or
# rejected, depending on the transformers version.
transcription = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    chunk_length_s=30,
)
# Audio-classification pipeline used by text_to_sentimient().
# NOTE(review): the checkpoint name suggests it was fine-tuned on MINDS-14
# (intent classification), not sentiment; confirm it matches the intended use.
clasification = pipeline(
    "audio-classification",
    model="anton-l/xtreme_s_xlsr_300m_minds14",
)

def audio_a_text(audio):
    """Transcribe an audio file to text with the Whisper ASR pipeline.

    Args:
        audio: Path to the audio file, as produced by
            ``gr.Audio(type="filepath")``; ``None`` when nothing was uploaded.

    Returns:
        The transcribed text, or ``""`` when no audio was provided.
    """
    if audio is None:
        # Gradio passes None if the button is clicked before any upload;
        # the ASR pipeline cannot process None, so short-circuit.
        return ""
    return transcription(audio)["text"]

def text_to_sentimient(audio):
    """Classify *audio* with the module-level audio-classification pipeline.

    Despite the name, the input is audio (anything ``clasification``
    accepts, e.g. a file path), not text. The pipeline's raw output is
    returned unchanged.
    """
    predictions = clasification(audio)
    return predictions

# Display name -> M2M100 language code. The keys are shown to users as dropdown
# choices; the codes are passed to the M2M100 tokenizer (src_lang / get_lang_id).
# Insertion order matters: the first key is used as the dropdowns' default value.
lang_id = {
    "Afrikaans": "af",
    "Amharic": "am",
    "Arabic": "ar",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Czech": "cs",
    "Welsh": "cy",
    "Danish": "da",
    "German": "de",
    "Greek": "el",
    "English": "en",
    "Spanish": "es",
    "Estonian": "et",
    "Persian": "fa",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Irish": "ga",
    "Gaelic": "gd",
    "Galician": "gl",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Croatian": "hr",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Armenian": "hy",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Georgian": "ka",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Burmese": "my",
    "Nepali": "ne",
    "Dutch": "nl",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Somali": "so",
    "Albanian": "sq",
    "Serbian": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Chinese": "zh",
    "Zulu": "zu",
}
def translation_text(source_lang, target_lang, user_input):
    """Translate *user_input* from *source_lang* to *target_lang* with M2M100.

    Args:
        source_lang: Display name of the source language (a key of ``lang_id``).
        target_lang: Display name of the target language (a key of ``lang_id``).
        user_input: Text to translate.

    Returns:
        The translated text, or ``""`` when *user_input* is empty or blank.

    Raises:
        KeyError: If, for non-blank input, either language name is not a key
            of ``lang_id``.
    """
    # Don't run the 1.2B-parameter model on empty or whitespace-only input.
    if not (user_input and user_input.strip()):
        return ""
    src_lang = lang_id[source_lang]
    trg_lang = lang_id[target_lang]
    # The M2M100 tokenizer uses src_lang to choose the source-language prefix token.
    # NOTE(review): this mutates the shared module-level tokenizer; confirm
    # requests are not processed concurrently.
    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(user_input, return_tensors="pt")
    with torch.no_grad():
        generated_tokens = translation_model.generate(
            **encoded_input,
            # Forcing the first decoded token selects the output language.
            forced_bos_token_id=tokenizer.get_lang_id(trg_lang),
        )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

def print_s(source_lang, target_lang, text0):
    """Debug helper: print the source language name, then echo the inputs.

    Returns a ``(source_code, target_code, text0)`` tuple, where both codes
    are looked up in ``lang_id`` (source first, then target).
    """
    print(source_lang)
    src_code = lang_id[source_lang]
    trg_code = lang_id[target_lang]
    return src_code, trg_code, text0
# Gradio UI: upload audio -> transcribe it into the input box -> translate it or
# identify its language, with results shown in the output box.
demo = gr.Blocks()

with demo:
    gr.Markdown("Speech analyzer")
    audio = gr.Audio(type="filepath", label="Upload a file")
    # Input text for translation / language identification (also receives the transcript).
    text0 = gr.Textbox(label="Input text")
    # Output of translation / language identification.
    text = gr.Textbox(label="Result")
    source_lang = gr.Dropdown(label="Source lang", choices=list(lang_id), value=list(lang_id)[0])
    target_lang = gr.Dropdown(label="target lang", choices=list(lang_id), value=list(lang_id)[0])

    b1 = gr.Button("convert to text")
    b3 = gr.Button("translate")
    b3.click(translation_text, inputs=[source_lang, target_lang, text0], outputs=text)
    # Send the transcript to text0 (not text) so the translate and language-ID
    # buttons, which both read text0, can act on the transcribed speech.
    b1.click(audio_a_text, inputs=audio, outputs=text0)

    b2 = gr.Button("Classification of language")
    b2.click(lang_ident, inputs=text0, outputs=text)

demo.launch()