"""
A Gradio app to transcribe and diarize a podcast using Whisper and pyannote. Adapted from Dwarkesh Patel's Colab notebook here:
https://colab.research.google.com/drive/1V-Bt5Hm2kjaDb4P1RyMSswsDKyrzc2-3?usp=sharing
"""
import whisper
import datetime

import subprocess
import torch
import gradio as gr
import pyannote.audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import wave
import contextlib

from sklearn.cluster import AgglomerativeClustering
import numpy as np

# Pick the fastest available torch backend: NVIDIA CUDA first, then Apple
# Metal (MPS), falling back to CPU.
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"

print(f"chosen device: {device_type}")


# Speaker-embedding model used (via segment_embedding) to tell speakers apart.
# It is created once at import time, so every request reuses the same instance.
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb", device=torch.device(device_type)
)

# The SpeechBrain ECAPA model presumably expects single-channel 16 kHz speech
# (per its model card). Have the reader resample and downmix every crop so that
# stereo files or other sample rates still produce a mono 16 kHz waveform,
# instead of a multi-channel / mismatched-rate input to the embedding model.
audio = Audio(sample_rate=16000, mono="downmix")


def time(secs):
    """Convert a duration in seconds to a ``datetime.timedelta`` of whole seconds.

    Rounding uses Python's built-in ``round`` (round-half-to-even, so 2.5 -> 2).
    ``str()`` of the result is an ``H:MM:SS`` timestamp for durations under a
    day, which ``transcribe`` uses to label each speaker turn.
    """
    whole_seconds = round(secs)
    return datetime.timedelta(seconds=whole_seconds)


def segment_embedding(segment, duration, audio, path):
    """Compute the speaker embedding of one Whisper segment.

    Args:
        segment: Whisper segment dict; only its "start" and "end" values
            (seconds) are read.
        duration: total length of the audio file in seconds, used to clamp
            the segment's end time.
        audio: pyannote ``Audio`` reader used to cut the clip out of *path*.
        path: audio file the segment belongs to.

    Returns:
        The output of the module-level ``embedding_model`` for the clip
        (presumably a (1, dim) array; confirm against pyannote's
        ``PretrainedSpeakerEmbedding``).
    """
    # Whisper can report an end timestamp past the end of the file on the
    # final segment, so never request audio beyond ``duration``.
    clip = Segment(segment["start"], min(duration, segment["end"]))
    waveform, _sample_rate = audio.crop(path, clip)
    # Prepend a batch axis before handing the clip to the embedding model.
    batch = waveform[None]
    return embedding_model(batch)


# Loaded Whisper models keyed by size name. At most one entry is kept, because
# the larger checkpoints (e.g. "large-v3") are very large.
_whisper_models = {}


def _load_whisper_model(model_type):
    """Return the Whisper model *model_type*, reusing it if it was the last one loaded."""
    model = _whisper_models.get(model_type)
    if model is None:
        # Drop the previously cached size first, so this cache never pins two
        # models in memory at once.
        _whisper_models.clear()
        model = whisper.load_model(model_type)
        _whisper_models[model_type] = model
    return model


def get_whisper_results(path, model_type):
    """Transcribe *path* with Whisper and read the file's length from its WAV header.

    The model is cached between calls (see ``_load_whisper_model``), so
    repeated requests with the same size do not reload the checkpoint.

    Args:
        path: path to a PCM WAV file (``wave`` only understands WAV).
        model_type: Whisper model size name, e.g. "base" or "large-v3".

    Returns:
        Tuple ``(result, segments, frames, rate, duration)``:
        - ``result``: the full dict returned by ``model.transcribe``.
        - ``segments``: ``result["segments"]``.
        - ``frames``: number of audio frames in the file.
        - ``rate``: frame rate in Hz.
        - ``duration``: ``frames / rate`` in seconds.
    """
    model = _load_whisper_model(model_type)
    result = model.transcribe(path)
    segments = result["segments"]

    # The exact file length is used later to clamp Whisper's end timestamps.
    with wave.open(path, "rb") as wav_file:
        frames = wav_file.getnframes()
        rate = wav_file.getframerate()
    duration = frames / float(rate)

    return result, segments, frames, rate, duration


def cluster_embeddings(segments, duration, path, num_speakers):
    """Label each Whisper segment with a speaker by clustering speaker embeddings.

    Mutates *segments* in place, adding a ``"speaker"`` key of the form
    ``"SPEAKER <n>"`` (n starting at 1) to every segment. Returns None.

    Args:
        segments: list of Whisper segment dicts (each needs "start"/"end").
        duration: audio length in seconds, forwarded to ``segment_embedding``.
        path: audio file the segments were transcribed from.
        num_speakers: expected number of distinct speakers. It may arrive as
            a float from the Gradio number input and is truncated to an int.

    Raises:
        ValueError: if *num_speakers* is less than 1.
    """
    if not segments:
        # Nothing to label (e.g. silent audio). Clustering zero samples would fail.
        return

    n_speakers = int(num_speakers)  # sklearn requires an int n_clusters
    if n_speakers < 1:
        raise ValueError(f"num_speakers must be at least 1, got {num_speakers!r}")
    # Agglomerative clustering cannot form more clusters than there are samples.
    n_clusters = min(n_speakers, len(segments))

    if n_clusters == 1:
        # One cluster needs no embeddings. This also avoids sklearn's minimum
        # of two samples when there is only a single segment.
        labels = np.zeros(len(segments), dtype=int)
    else:
        # Stack one embedding row per segment; the width comes from the model
        # rather than a hard-coded dimension.
        embeddings = np.vstack(
            [segment_embedding(segment, duration, audio, path) for segment in segments]
        ).astype(np.float64)
        embeddings = np.nan_to_num(embeddings)
        labels = AgglomerativeClustering(n_clusters=n_clusters).fit(embeddings).labels_

    for segment, label in zip(segments, labels):
        segment["speaker"] = "SPEAKER " + str(label + 1)


def _to_wav(path, workdir):
    """Transcode *path* with ffmpeg into a 16 kHz mono PCM WAV in *workdir*; return its path."""
    import os

    wav_path = os.path.join(workdir, "audio.wav")
    # check=True raises CalledProcessError if ffmpeg fails, instead of carrying
    # on with a missing (or, with a shared output file, stale) audio.wav.
    subprocess.run(
        ["ffmpeg", "-y", "-i", path, "-ar", "16000", "-ac", "1", wav_path],
        check=True,
    )
    return wav_path


def _format_transcript(segments):
    """Join speaker-labelled segments into text, adding a header line at each speaker change."""
    parts = []
    previous_speaker = None
    for segment in segments:
        speaker = segment["speaker"]
        if speaker != previous_speaker:
            parts.append(f"\n{speaker} {time(segment['start'])}\n")
            previous_speaker = speaker
        # Whisper text usually starts with a space. Strip whitespace rather than
        # slicing off the first character, which would delete a real character
        # when that leading space is absent.
        parts.append(segment["text"].strip() + " ")
    return "".join(parts)


def transcribe(path, model_type, num_speakers):
    """Transcribe and diarize an uploaded audio file; return the labelled transcript.

    Non-WAV input is first converted to WAV in a per-request temporary
    directory, which is deleted afterwards. Concurrent requests therefore never
    share or reuse an intermediate file.

    Args:
        path: filesystem path of the uploaded audio file.
        model_type: Whisper model size name passed to ``get_whisper_results``.
        num_speakers: expected number of speakers, passed to ``cluster_embeddings``.

    Returns:
        The transcript text. Each speaker turn starts on a new line with
        ``"SPEAKER <n> H:MM:SS"``.

    Raises:
        gr.Error: if no file was uploaded.
        subprocess.CalledProcessError: if ffmpeg fails to convert the input.
    """
    import tempfile

    if not path:
        raise gr.Error("Please upload an audio file.")

    with tempfile.TemporaryDirectory() as workdir:
        if not path.lower().endswith(".wav"):
            path = _to_wav(path, workdir)

        print("running whisper...")
        _result, segments, _frames, _rate, duration = get_whisper_results(path, model_type)
        print("done running whisper. Clustering embeddings...")
        cluster_embeddings(segments, duration, path, num_speakers)
        print("done clustering embeddings. Time to return...")

    return _format_transcript(segments)


if __name__ == "__main__":
    # Build and launch the web UI. share=True also requests a public share link.
    interface = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.File(file_count="single", label="Upload an audio file"),
            gr.Radio(
                choices=["tiny", "base", "small", "medium", "large-v3"],
                value="large-v3",
                type="value",
                label="Model size",
            ),
            # precision=0 makes Gradio deliver an int; the speaker count ends
            # up as sklearn's n_clusters, which rejects floats like 2.0.
            gr.Number(
                value=2,
                precision=0,
                label="Number of speakers",
            ),
        ],
        outputs=gr.Textbox(label="Transcript", show_copy_button=True),
        title="Transcribe a podcast!",
        description="Upload an audio file and choose a model size and number of speakers on the left, then click submit to transcribe!",
        theme=gr.themes.Soft(),
    )
    interface.launch(share=True)