import os
from pathlib import Path
from typing import Iterator, Tuple
import gradio as gr
from transformers import pipeline, Pipeline
from huggingface_hub import repo_exists


from speech_to_text_finetune.config import LANGUAGES_NAME_TO_ID

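# IS_HF_SPACE is expected to be set in the Hugging Face Space environment;
# it is used below to hide the local-model-directory option from the UI.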
is_hf_space = os.getenv("IS_HF_SPACE")
languages = LANGUAGES_NAME_TO_ID.keys()
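# Dropdown choices for Option 1; the leading empty string keeps the dropdown unselected by default.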
model_ids = [
    "",
    "openai/whisper-tiny",
    "openai/whisper-small",
    "openai/whisper-medium",
    "openai/whisper-large-v3",
    "openai/whisper-large-v3-turbo",
]


def _load_local_model(model_dir: str, language: str) -> Tuple[Pipeline | None, str]:
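    """Load a fine-tuned Whisper model from a local directory and wrap it in an ASR pipeline.

    Returns a (pipeline, status_message) tuple; the pipeline is None if the directory does not exist.
    """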
    if not Path(model_dir).is_dir():
        return None, f"⚠️ Couldn't find local model directory: {model_dir}"
    from transformers import (
        WhisperProcessor,
        WhisperTokenizer,
        WhisperFeatureExtractor,
        WhisperForConditionalGeneration,
    )

    processor = WhisperProcessor.from_pretrained(model_dir)
    tokenizer = WhisperTokenizer.from_pretrained(
        model_dir, language=language, task="transcribe"
    )
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_dir)
    model = WhisperForConditionalGeneration.from_pretrained(model_dir)

    return pipeline(
        task="automatic-speech-recognition",
        model=model,
        processor=processor,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
    ), f"✅ Local model has been loaded from {model_dir}."


def _load_hf_model(model_repo_id: str, language: str) -> Tuple[Pipeline | None, str]:
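    """Load a Whisper model from the Hugging Face Hub and wrap it in an ASR pipeline.

    Returns a (pipeline, status_message) tuple; the pipeline is None if the repo cannot be found.
    """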
    if not repo_exists(model_repo_id):
        return (
            None,
            f"⚠️ Couldn't find {model_repo_id} on Hugging Face. If it's a private repo, make sure you are logged in locally.",
        )
    return pipeline(
        "automatic-speech-recognition",
        model=model_repo_id,
        generate_kwargs={"language": language},
    ), f"✅ HF Model {model_repo_id} has been loaded."


def load_model(
    language: str, dropdown_model_id: str, hf_model_id: str, local_model_id: str
) -> Iterator[Tuple[Pipeline | None, str]]:
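    """Load a model from exactly one of the three UI options and stream status updates to the UI.

    Yields (pipeline, status_message) tuples so Gradio can show progress while the model loads.
    """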
    # Validate the language first, so a model is never loaded only to be discarded afterwards.
    if not language:
        yield None, "⚠️ Please select a language from the dropdown"
        return
    # Exactly one of the three model options must be provided.
    if dropdown_model_id and not hf_model_id and not local_model_id:
        yield None, f"Loading {dropdown_model_id}..."
        yield _load_hf_model(dropdown_model_id, language)
    elif hf_model_id and not local_model_id and not dropdown_model_id:
        yield None, f"Loading {hf_model_id}..."
        yield _load_hf_model(hf_model_id, language)
    elif local_model_id and not hf_model_id and not dropdown_model_id:
        yield None, f"Loading {local_model_id}..."
        yield _load_local_model(local_model_id, language)
    else:
        yield (
            None,
            "⚠️ Please select or fill in exactly one of the model options above",
        )


def transcribe(pipe: Pipeline, audio: str) -> str:
    # `audio` is a file path, since gr.Audio below is configured with type="filepath".
    if pipe is None:
        return "⚠️ Please load a model first by clicking the Load model button."
    return pipe(audio)["text"]


def setup_gradio_demo():
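    """Build and launch the Gradio demo: language/model selection, model loading, and transcription."""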
    with gr.Blocks() as demo:
        gr.Markdown(
            """ # 🗣️ Speech-to-Text Transcription
            ### 1. Select a language from the dropdown menu.
            ### 2. Select which model to load from one of the options below.
            ### 3. Load the model by clicking the Load model button.
            ### 4. Record a message or upload an audio file.
            ### 5. Click Transcribe to see the transcription generated by the model.
            """
        )
        ### Language & Model selection ###

        selected_lang = gr.Dropdown(
            choices=list(languages), value=None, label="Select a language"
        )

        with gr.Row():
            with gr.Column():
                dropdown_model = gr.Dropdown(
                    choices=model_ids, label="Option 1: Select a model"
                )
            with gr.Column():
                user_model = gr.Textbox(
                    label="Option 2: Paste HF model id",
                    placeholder="my-username/my-whisper-tiny",
                )
            with gr.Column(visible=not is_hf_space):
                local_model = gr.Textbox(
                    label="Option 3: Paste local path to model directory",
                    placeholder="artifacts/my-whisper-tiny",
                )

        load_model_button = gr.Button("Load model")
        model_loaded = gr.Markdown()

        ### Transcription ###
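        # Recordings/uploads are capped at 30 seconds, matching Whisper's 30-second context window.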
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="Record a message / Upload audio file",
            show_download_button=True,
            max_length=30,
        )
        transcribe_button = gr.Button("Transcribe")
        transcribe_output = gr.Text(label="Output")

        ### Event listeners ###
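        # gr.State stores the loaded pipeline per user session so the Transcribe button can reuse it.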
        model = gr.State()
        load_model_button.click(
            fn=load_model,
            inputs=[selected_lang, dropdown_model, user_model, local_model],
            outputs=[model, model_loaded],
        )

        transcribe_button.click(
            fn=transcribe, inputs=[model, audio_input], outputs=transcribe_output
        )

    demo.launch()


if __name__ == "__main__":
    setup_gradio_demo()