Noumida committed on
Commit
d95af38
·
verified ·
1 Parent(s): eba970d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import torch
5
  import torchaudio
6
  import spaces
7
- import nemo.collections.asr as nemo_asr
8
 
9
  LANGUAGE_NAME_TO_CODE = {
10
  "Assamese": "as", "Bengali": "bn", "Bodo": "br", "Dogri": "doi",
@@ -15,30 +15,43 @@ LANGUAGE_NAME_TO_CODE = {
15
  "Telugu": "te", "Urdu": "ur"
16
  }
17
 
18
- DESCRIPTION = """IndicConformer: Dual-Decoder ASR for Indian Languages"""
19
 
20
- device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
21
- model = nemo_asr.models.EncDecCTCModel.from_pretrained("ai4bharat/IndicConformer").to(device)
22
- model.eval()
 
 
 
 
 
 
 
23
 
24
  @spaces.GPU
25
  def transcribe_ctc_and_rnnt(audio_path, language_name):
26
  lang_id = LANGUAGE_NAME_TO_CODE[language_name]
27
- waveform, sample_rate = torchaudio.load(audio_path)
 
28
  waveform = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
29
- waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
30
- waveform_np = waveform.squeeze().numpy()
 
31
 
32
- model.cur_decoder = "ctc"
33
- ctc = model.transcribe([waveform_np], batch_size=1, language_id=lang_id)[0][0]
 
 
 
34
 
35
- model.cur_decoder = "rnnt"
36
- rnnt = model.transcribe([waveform_np], batch_size=1, language_id=lang_id)[0][0]
37
 
38
- return ctc, rnnt
39
 
 
40
  with gr.Blocks() as demo:
41
- gr.Markdown(DESCRIPTION)
42
  with gr.Row():
43
  with gr.Column():
44
  audio = gr.Audio(label="Upload or record audio", type="filepath")
 
4
  import torch
5
  import torchaudio
6
  import spaces
7
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoModelForCTC
8
 
9
  LANGUAGE_NAME_TO_CODE = {
10
  "Assamese": "as", "Bengali": "bn", "Bodo": "br", "Dogri": "doi",
 
15
  "Telugu": "te", "Urdu": "ur"
16
  }
17
 
18
+ DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"
19
 
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+
22
+ # Load processor and models
23
+ processor = AutoProcessor.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
24
+
25
+ model_ctc = AutoModelForCTC.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
26
+ model_ctc.eval()
27
+
28
+ model_rnnt = AutoModelForSpeechSeq2Seq.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
29
+ model_rnnt.eval()
30
 
31
  @spaces.GPU
32
  def transcribe_ctc_and_rnnt(audio_path, language_name):
33
  lang_id = LANGUAGE_NAME_TO_CODE[language_name]
34
+
35
+ waveform, sr = torchaudio.load(audio_path)
36
  waveform = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
37
+ waveform = torchaudio.functional.resample(waveform, sr, 16000)
38
+
39
+ input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_values.to(device)
40
 
41
+ with torch.no_grad():
42
+ # CTC decoding
43
+ ctc_logits = model_ctc(input_values).logits
44
+ ctc_ids = torch.argmax(ctc_logits, dim=-1)
45
+ ctc_output = processor.batch_decode(ctc_ids)[0]
46
 
47
+ # RNNT decoding
48
+ rnnt_output = processor.batch_decode(model_rnnt.generate(input_values, decoder_input_ids=torch.tensor([[processor.tokenizer.lang2id[lang_id]]]).to(device)))[0]
49
 
50
+ return ctc_output.strip(), rnnt_output.strip()
51
 
52
+ # Gradio interface
53
  with gr.Blocks() as demo:
54
+ gr.Markdown(f"## {DESCRIPTION}")
55
  with gr.Row():
56
  with gr.Column():
57
  audio = gr.Audio(label="Upload or record audio", type="filepath")