import sys
import time

from importlib.metadata import version

from os import remove
from os.path import exists

import numpy as np

import torch
import torchaudio
import torchaudio.transforms as T

import streamlit as st

from streamlit.runtime.uploaded_file_manager import UploadedFile
from transformers import HubertForCTC, Wav2Vec2Processor


# Config
model_name = "Yehor/hubert-uk"

torchaudio_backend = "soundfile"

min_duration = 0.5
max_duration = 60

concurrency_limit = 5
use_torch_compile = False

# Torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the model
asr_model = HubertForCTC.from_pretrained(
    model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2Processor.from_pretrained(model_name)

if use_torch_compile:
    asr_model = torch.compile(asr_model)

# Elements
examples = [
    "example_1.wav",
    "example_2.wav",
    "example_3.wav",
    "example_4.wav",
    "example_5.wav",
    "example_6.wav",
]

examples_table = """
| File  | Text |
| ------------- | ------------- |
| `example_1.wav`  | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav`  | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav`  | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні  |
| `example_4.wav`  | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav`  | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav`  | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()

authors_table = """
## Authors

Follow them in social networks and **contact** if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> <br> **Yehor Smoliakov** |
|------------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram                                                                       |
| https://x.com/yehor_smoliakov at X                                                                   |
| https://github.com/egorsmkv at GitHub                                                                |
| https://huggingface.co/Yehor at Hugging Face                                                         |
| or use egorsmkv@gmail.com                                                                            |
""".strip()

description_head = f"""
## Overview

This space uses https://huggingface.co/Yehor/hubert-uk model to recognize audio files.

> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()

description_foot = f"""
## Community

- **Discord**: https://discord.gg/yVAjkBgmt4
- Speech Recognition: https://t.me/speech_recognition_uk
- Speech Synthesis: https://t.me/speech_synthesis_uk

## More

Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk

{authors_table}
""".strip()

transcription_value = """
Recognized text will appear here.

Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record own voice.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
""".strip()

tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- torchaudio: {version('torchaudio')}
- transformers: {version('transformers')}
- accelerate: {version('accelerate')}
- streamlit: {version('streamlit')}
""".strip()


# UploadedFile
def inference(uploaded_file: UploadedFile):
    audio_path = uploaded_file.file_id + '.wav'

    with open(audio_path, 'wb') as f:
        f.write(uploaded_file.getvalue())

    if not audio_path:
        st.error("Please upload an audio file.")
        return

    st.info("Starting recognition")

    meta = torchaudio.info(audio_path, backend=torchaudio_backend)
    duration = meta.num_frames / meta.sample_rate

    if duration < min_duration:
        st.error(
            f"The duration of the file is less than {min_duration} seconds, it is {round(duration, 2)} seconds."
        )
        return
    if duration > max_duration:
        st.error(f"The duration of the file exceeds {max_duration} seconds.")
        return

    paths = [
        audio_path,
    ]

    results = []

    for path in paths:
        t0 = time.time()

        meta = torchaudio.info(audio_path, backend=torchaudio_backend)
        audio_duration = meta.num_frames / meta.sample_rate

        audio_input, sr = torchaudio.load(path, backend=torchaudio_backend)

        if meta.num_channels > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)

        if meta.sample_rate != 16_000:
            resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze(0).numpy()

        inputs = processor(
            [audio_input], sampling_rate=16_000, padding=True
        ).input_values
        features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)

        with torch.inference_mode():
            logits = asr_model(features).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predictions = processor.batch_decode(predicted_ids)

        if not predictions:
            predictions = "-"

        elapsed_time = round(time.time() - t0, 2)
        rtf = round(elapsed_time / audio_duration, 4)
        audio_duration = round(audio_duration, 2)

        results.append(
            {
                "path": path.split("/")[-1],
                "transcription": "\n".join(predictions),
                "audio_duration": audio_duration,
                "rtf": rtf,
            }
        )

    st.info("Finished!")

    result_texts = []

    for result in results:
        result_texts.append(f'**{result["path"]}**')
        result_texts.append("\n\n")
        result_texts.append(f'> {result["transcription"]}')
        result_texts.append("\n\n")
        result_texts.append(f'**Audio duration**: {result["audio_duration"]}')
        result_texts.append("\n")
        result_texts.append(f'**Real-Time Factor**: {result["rtf"]}')

    if exists(audio_path):
        remove(audio_path)

    return "\n".join(result_texts)


st.title("Speech-to-Text for Ukrainian using HuBERT")
st.markdown(description_head)

st.markdown("## Usage")

audio_file = st.file_uploader("Upload an audio file", type=["wav"])

if st.button("Recognize"):
    if audio_file is not None:
        transcription = inference(audio_file)
        st.markdown(transcription)
    else:
        st.error("Please upload an audio file.")

st.markdown("### Examples")
st.markdown(examples_table)

st.markdown(description_foot, unsafe_allow_html=True)

st.markdown("### Environment")
st.markdown(tech_env)
st.markdown(tech_libraries)