File size: 3,070 Bytes
d93a410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb47424
 
 
a30dc65
bb47424
9bc7b85
bb47424
 
d93a410
 
 
 
 
 
 
 
 
 
6458569
d93a410
 
 
 
 
 
6458569
d93a410
 
 
 
 
0d17251
6458569
d93a410
 
2f57ed2
 
 
 
 
 
 
 
 
 
0d17251
d93a410
a30dc65
d93a410
 
 
a30dc65
d93a410
 
 
 
2f57ed2
d93a410
 
 
 
 
2a1faee
7571349
d93a410
 
 
2f57ed2
d93a410
 
 
f3f5631
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) 2022 Horizon Robotics. (authors: Binbin Zhang)
#               2022 Chengdong Liang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gradio as gr
import torch
from wenet.cli.model import load_model



def process_cat_embs(cat_embs, device="cpu"):
    """Parse a comma-separated string of weights into a 1-D float tensor.

    Args:
        cat_embs: String such as ``"0.3,0.7"``; every comma-separated field
            must be accepted by ``float()``.
        device: Device the resulting tensor is created on. Defaults to
            ``"cpu"``, which was previously hard-coded.

    Returns:
        A 1-D ``torch.Tensor`` holding one element per field.
    """
    values = [float(field) for field in cat_embs.split(',')]
    return torch.tensor(values, device=device)


def download_rev_models():
    """Download the Reverb ASR checkpoint and unit list, then load the model.

    Fetches ``reverb_asr_v1.jit.zip`` and ``tk.units.txt`` from the
    ``Revai/reverb-asr`` Hugging Face Hub repository and passes the two
    local file paths to ``load_model``.

    Returns:
        Whatever ``load_model`` returns for those two files.
    """
    from huggingface_hub import hf_hub_download

    repo_id = "Revai/reverb-asr"

    model_path = hf_hub_download(repo_id=repo_id,
                                 filename='reverb_asr_v1.jit.zip')
    units_path = hf_hub_download(repo_id=repo_id, filename='tk.units.txt')
    return load_model(model_path, units_path)

# Load the model once, at import time; `recognition` reuses this single
# instance for every call. This blocks until the download and load finish.
model = download_rev_models()
    

def recognition(audio, style=0):
    """Transcribe one audio file with the requested transcription style.

    Args:
        audio: Path to the input audio file (the audio input component is
            created with ``type="filepath"``). ``None`` or an empty path
            means no audio was provided.
        style: Verbatimicity value from the slider, from 0 (non-verbatim)
            to 1 (verbatim). ``None`` is treated as 0.

    Returns:
        The transcript with every ``'▁'`` replaced by a space, or an error
        message string when there is no input or no model output.
    """
    # Reject both a missing value and an empty path; neither can be
    # transcribed.
    if not audio:
        return "Input Error! Please enter one audio!"
    # NOTE: model supports 16k sample_rate

    # Defensive: a missing slider value falls back to the default instead
    # of failing on ``1 - None``.
    if style is None:
        style = 0

    # Two-element category embedding weighted (style, 1 - style), serialized
    # into the comma-separated form that process_cat_embs parses.
    cat_embs = ','.join(str(s) for s in (style, 1 - style))
    cat_embs = process_cat_embs(cat_embs)
    ans = model.transcribe(audio, cat_embs=cat_embs)

    if ans is None:
        return "ERROR! No text output! Please try again!"
    return ans['text'].replace('▁', ' ')


# Interface inputs, passed positionally to ``recognition(audio, style)``.
inputs = [
    # ``type="filepath"`` makes ``recognition`` receive a file path.
    gr.Audio(source="microphone", type="filepath", label='Input audio'),
    gr.Slider(
        0,
        1,
        value=0,
        label="Verbatimicity - from non-verbatim (0) to verbatim (1)",
        info="Choose a transcription style between non-verbatim and verbatim",
    ),
]

# Bundled sample clips; each row supplies a value for the audio input only.
examples = [[clip_path] for clip_path in (
    'examples/POD1000000012_S0000335.wav',
    'examples/POD1000000013_S0000062.wav',
    'examples/POD1000000032_S0000020.wav',
    'examples/POD1000000032_S0000038.wav',
    'examples/POD1000000032_S0000050.wav',
    'examples/POD1000000032_S0000058.wav',
)]


# Output component that displays the string returned by ``recognition``.
output = gr.Textbox(label="Output Text")

# Page title, passed to ``gr.Interface`` as ``title=``.
text = "Reverb ASR Transcription Styles Demo"

# Description text, passed to ``gr.Interface`` as ``description=``.
description = (
    "Reverb ASR supports verbatim and non-verbatim transcription. "
    "Try recording an audio with disfluencies (ex: 'uh', 'um') and "
    "testing both transcription styles. Or, choose an example audio below."
)

# Footer HTML, passed to ``gr.Interface`` as ``article=``: a centred link.
article = "".join((
    "<p style='text-align: center'>",
    "<a href='https://rev.com' target='_blank'>Learn more about Rev</a>",
    "</p>",
))

# Assemble the demo: the recognizer, its input/output components, page text,
# and clickable examples. ``theme`` must be given only once -- repeating a
# keyword argument is a SyntaxError.
interface = gr.Interface(
    fn=recognition,
    inputs=inputs,
    outputs=output,
    theme="Nymbo/Nymbo_Theme",
    title=text,
    description=description,
    article=article,
    examples=examples,
)

# Enable request queuing, then start the server. ``queue()`` returns the
# interface, so the calls chain; this replaces the deprecated
# ``launch(enable_queue=True)`` argument.
interface.queue().launch()