import os
import time
import warnings
from pathlib import Path

import gradio as gr
import librosa
import spaces
import torch
from loguru import logger
from transformers import pipeline

# Silence every warning category process-wide to keep the logs readable.
warnings.filterwarnings(action="ignore")

# True when the SYSTEM environment variable is set to "spaces".
is_hf = os.environ.get("SYSTEM") == "spaces"

# Decoding options forwarded to each pipeline call as ``generate_kwargs``.
generate_kwargs = dict(
    language="Japanese",
    do_sample=False,
    num_beams=1,
    no_repeat_ngram_size=5,
    max_new_tokens=64,
)

# Display name -> model identifier; the display name is the part of the
# identifier after the "owner/" prefix.
model_dict = {
    repo_id.rpartition("/")[2]: repo_id
    for repo_id in (
        "openai/whisper-large-v3-turbo",
        "kotoba-tech/kotoba-whisper-v2.0",
        "litagin/anime-whisper",
    )
}

logger.info("Initializing pipelines...")
# Every pipeline shares the same placement: half precision on CUDA when it is
# available, full precision on CPU otherwise.
_cuda_available = torch.cuda.is_available()
_device = "cuda" if _cuda_available else "cpu"
_dtype = torch.float16 if _cuda_available else torch.float32

# Display name -> ready-to-call speech recognition pipeline, built eagerly at
# import time for every entry of ``model_dict``.
pipe_dict = {
    display_name: pipeline(
        "automatic-speech-recognition",
        model=repo_id,
        device=_device,
        torch_dtype=_dtype,
    )
    for display_name, repo_id in model_dict.items()
}
logger.success("Pipelines initialized!")


@spaces.GPU
def transcribe_common(audio: str, model: str) -> str:
    """Transcribe an audio file with one of the preloaded pipelines.

    The audio is loaded as 16 kHz mono, passed to ``pipe_dict[model]`` and the
    recognized text is written to ``<audio stem>_<model>.str`` in the current
    working directory.

    Args:
        audio: Path of the audio file to transcribe. A falsy value (nothing
            uploaded) short-circuits with a placeholder message.
        model: Key into ``pipe_dict`` selecting the pipeline to run.

    Returns:
        Path of the written transcription file, or the literal string
        ``"No audio file"`` when ``audio`` is falsy.

    Raises:
        KeyError: If ``model`` is not a key of ``pipe_dict``.
    """
    if not audio:
        # NOTE(review): the callers in this file feed this return value into
        # gr.File outputs, which presumably expect a path to an existing file.
        # Kept for interface compatibility; confirm whether a user-facing
        # error would be more appropriate.
        return "No audio file"
    filename = Path(audio).name
    logger.info(f"Model: {model}")
    logger.info(f"Audio: {filename}")

    try:
        y, sr = librosa.load(audio, mono=True, sr=16000)
    except Exception as load_error:
        # Best-effort fallback: convert the input to WAV with pydub and load
        # that instead. The conversion happens in a private temporary
        # directory so concurrent requests cannot overwrite each other's file
        # and nothing is left behind if the second load fails as well.
        import tempfile

        from pydub import AudioSegment

        logger.warning(
            f"librosa could not load {filename} ({load_error!r}); "
            "retrying via pydub"
        )
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = str(Path(tmp_dir) / "converted.wav")
            # export() hands back the open output file; close it so the data
            # is flushed and the directory can be removed afterwards.
            AudioSegment.from_file(audio).export(wav_path, format="wav").close()
            y, sr = librosa.load(wav_path, mono=True, sr=16000)

    duration = librosa.get_duration(y=y, sr=sr)
    logger.info(f"Duration: {duration:.2f}s")
    kwargs = generate_kwargs.copy()
    if duration > 30:
        # Clips longer than 30 seconds additionally request timestamps.
        kwargs["return_timestamps"] = True

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    result = pipe_dict[model](y, generate_kwargs=kwargs)["text"]
    elapsed = time.perf_counter() - start_time
    logger.success(f"Finished in {elapsed:.2f}s\n{result}")

    # Save the transcription to a .str file named after the audio and model.
    output_path = f"{Path(filename).stem}_{model}.str"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result)

    logger.info(f"Transcription saved to {output_path}")
    return output_path  # Path of the transcription file


def transcribe_others(audio) -> tuple[str, str]:
    """Run both comparison models on the same audio input.

    Returns:
        The results of ``transcribe_common`` for ``whisper-large-v3-turbo``
        and ``kotoba-whisper-v2.0``, in that order.
    """
    comparison_models = ("whisper-large-v3-turbo", "kotoba-whisper-v2.0")
    # The generator is consumed left to right, so the models run sequentially
    # in the order listed above.
    turbo_output, kotoba_output = (
        transcribe_common(audio, model_name) for model_name in comparison_models
    )
    return turbo_output, kotoba_output


def transcribe_anime_whisper(audio) -> str:
    """Transcribe ``audio`` with the Anime-Whisper pipeline.

    Returns:
        Whatever ``transcribe_common`` returns for the ``anime-whisper`` model.
    """
    model_key = "anime-whisper"
    return transcribe_common(audio, model_key)


# Markdown header rendered at the top of the demo page (user-facing, Japanese).
# It introduces the Anime-Whisper model, links the two comparison models and
# shows a copy of the generation kwargs passed to the pipelines.
# NOTE(review): the text tells users that audio is limited to 15 seconds, but
# no such limit is enforced in the code visible in this file — confirm.
# NOTE(review): the embedded snippet is a hand-maintained copy of
# ``generate_kwargs``; keep the two in sync when decoding options change.
initial_md = """
# Anime-Whisper Demo
[**Anime Whisper**](https://huggingface.co/litagin/anime-whisper): 5千時間以上のアニメ調セリフと台本でファインチューニングされた日本語音声認識モデルのデモです。句読点や感嘆符がリズムや感情に合わせて自然に付き、NSFW含む非言語発言もうまく台本調に書き起こされます。
- デモでは**音声は15秒まで**しか受け付けません
- 日本語のみ対応 (Japanese only)
- 比較のために [openai/whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) と [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0) も用意しています
pipeに渡しているkwargsは以下:
```python
generate_kwargs = {
    "language": "Japanese",
    "do_sample": False,
    "num_beams": 1,
    "no_repeat_ngram_size": 5,
    "max_new_tokens": 64,  # 結果が長いときは途中で打ち切られる
}
```
"""

# Build the Gradio page: the Markdown header, one shared audio input, an
# Anime-Whisper section and a comparison section for the two other models.
with gr.Blocks() as app:
    gr.Markdown(initial_md)
    # Single audio input used by both buttons; type="filepath" presumably hands
    # the callbacks a path string, which is what transcribe_common expects.
    audio = gr.Audio(type="filepath")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Anime-Whisper")
            button_galgame = gr.Button("Transcribe with Anime-Whisper")
            # Receives the transcription file path returned by the callback.
            output_galgame = gr.File(label="Download Anime-Whisper Transcription")
    gr.Markdown("### Comparison")
    button_others = gr.Button("Transcribe with other models")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Whisper-Large-V3-Turbo")
            output_v3 = gr.File(label="Download Whisper-Large-V3-Turbo Transcription")
        with gr.Column():
            gr.Markdown("### Kotoba-Whisper-V2.0")
            output_kotoba_v2 = gr.File(label="Download Kotoba-Whisper-V2.0 Transcription")

    # Anime-Whisper button: one callback result -> one file output.
    button_galgame.click(
        transcribe_anime_whisper,
        inputs=[audio],
        outputs=[output_galgame],
    )
    # Comparison button: transcribe_others returns a 2-tuple whose order
    # (large-v3-turbo, kotoba-v2.0) matches the order of the outputs below.
    button_others.click(
        transcribe_others,
        inputs=[audio],
        outputs=[output_v3, output_kotoba_v2],
    )

# Runs unconditionally when this module is executed or imported (there is no
# ``__main__`` guard); inbrowser=True asks Gradio to open the app in a browser.
app.launch(inbrowser=True)