pratikshahp committed on
Commit
16d11ec
·
verified ·
1 Parent(s): a976d7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -3,12 +3,16 @@ from transformers import Speech2TextProcessor, Speech2TextForConditionalGenerati
3
  from audio_recorder_streamlit import audio_recorder
4
  import numpy as np
5
  import streamlit as st
 
6
  def transcribe_audio(audio_bytes):
7
-
8
  model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
9
  processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
10
 
11
- generated_ids = model.generate(input_ids=audio_bytes["input_features"], attention_mask=audio_bytes["attention_mask"])
 
 
 
 
12
  translation = processor.batch_decode(generated_ids, skip_special_tokens=True)
13
 
14
  return translation
@@ -25,4 +29,4 @@ if audio_bytes:
25
  else:
26
  st.write("Error: Failed to transcribe audio.")
27
  else:
28
- st.write("No audio recorded.")
 
3
  from audio_recorder_streamlit import audio_recorder
4
  import numpy as np
5
  import streamlit as st
6
+
7
  def transcribe_audio(audio_bytes):
 
8
  model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
9
  processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
10
 
11
+ # Convert audio bytes to tensors
12
+ input_features = torch.tensor(audio_bytes).unsqueeze(0) # Assuming audio_bytes is numpy array
13
+
14
+ # Generate transcription
15
+ generated_ids = model.generate(input_features)
16
  translation = processor.batch_decode(generated_ids, skip_special_tokens=True)
17
 
18
  return translation
 
29
  else:
30
  st.write("Error: Failed to transcribe audio.")
31
  else:
32
+ st.write("No audio recorded.")