Lguyogiro committed
Commit f4e441a · 1 Parent(s): a6be9e9

initial commit

Files changed (3)
  1. app.py +95 -0
  2. asr.py +41 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import sys
+ import datetime
+ import streamlit as st
+ from asr import load_model, inference
+ from audio_recorder_streamlit import audio_recorder
+
+
+ @st.cache_resource
+ def load_asr_model():
+     return load_model()
+
+ processor, asr_model = load_asr_model()
+
+
+ def save_audio_file(audio_bytes, file_extension):
+     """
+     Save audio bytes to a file with the specified extension.
+
+     :param audio_bytes: Audio data in bytes
+     :param file_extension: The extension of the output audio file
+     :return: The name of the saved audio file
+     """
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+     file_name = f"audio_{timestamp}.{file_extension}"
+
+     with open(file_name, "wb") as f:
+         f.write(audio_bytes)
+
+     return file_name
+
+
+ def transcribe_audio(file_path):
+     """
+     Transcribe the audio file at the specified path.
+
+     :param file_path: The path of the audio file to transcribe
+     :return: The transcribed text
+     """
+     with open(file_path, "rb") as audio_file:
+         transcript = inference(processor, asr_model, audio_file)
+     return transcript
+
+
+ def main():
+     """Record or upload audio, transcribe it, and offer a download."""
+     st.title("Anishinaabemowin Transcription")
+     tab1, tab2 = st.tabs(["Record Audio", "Upload Audio"])
+
+     # Record Audio tab
+     with tab1:
+         audio_bytes = audio_recorder()
+         if audio_bytes:
+             st.audio(audio_bytes, format="audio/wav")
+             save_audio_file(audio_bytes, "wav")
+
+     # Upload Audio tab
+     with tab2:
+         audio_file = st.file_uploader("Upload Audio", type=["wav"])
+         if audio_file:
+             file_extension = audio_file.type.split("/")[1]
+             save_audio_file(audio_file.read(), file_extension)
+
+     # Transcribe button action
+     if st.button("Transcribe"):
+         # Streamlit reruns the script on every click, so locate the newest
+         # saved audio file rather than relying on a variable set in the tabs.
+         audio_files = [f for f in os.listdir(".") if f.startswith("audio")]
+         if not audio_files:
+             st.error("Please record or upload audio before transcribing.")
+             return
+         audio_file_path = max(audio_files, key=os.path.getctime)
+
+         # Transcribe the audio file
+         transcript_text = transcribe_audio(audio_file_path)
+
+         # Display the transcript
+         st.header("Transcript")
+         st.write(transcript_text)
+
+         # Save the transcript to a text file
+         with open("transcript.txt", "w") as f:
+             f.write(transcript_text)
+
+         # Provide a download button for the transcript
+         st.download_button("Download Transcript", transcript_text, file_name="transcript.txt")
+
+
+ if __name__ == "__main__":
+     # Make the app's directory importable so `from asr import ...` works
+     working_dir = os.path.dirname(os.path.abspath(__file__))
+     sys.path.append(working_dir)
+
+     # Run the main function
+     main()
asr.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
+ import torchaudio
+ import torch
+ import os
+ import librosa
+
+ hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
+
+ def read_audio_data(file):
+     speech_array, sampling_rate = torchaudio.load(file, normalize=True)
+     return speech_array, sampling_rate
+
+ def load_model():
+     model_id = "Lguyogiro/wav2vec2-large-mms-1b-nhi-adapterft-ilv_fold1"
+     target_lang = "nhi"
+     processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang, use_auth_token=hf_token)
+     model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True, use_safetensors=True, use_auth_token=hf_token)
+     return processor, model
+
+
+ def inference(processor, model, audio_path):
+     # librosa accepts a path or file object; resample to the 16 kHz rate the model expects
+     audio, sampling_rate = librosa.load(audio_path, sr=16000)
+     inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+
+     with torch.no_grad():
+         logits = model(inputs.input_values).logits
+
+     # Decode predicted tokens
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+     # Earlier torchaudio-based implementation, kept for reference:
+     # arr, rate = read_audio_data(audio_path)
+     # inputs = processor(arr.squeeze().numpy(), sampling_rate=16_000, return_tensors="pt")
+     # with torch.no_grad():
+     #     outputs = model(**inputs).logits
+     # ids = torch.argmax(outputs, dim=-1)[0]
+     # transcription = processor.decode(ids)
+
+     return transcription
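
For a quick sanity check outside the Streamlit app, asr.py can be exercised on its own. A minimal sketch, assuming a local test file sample.wav (illustrative, not part of this commit) and an access token set before the module is imported, since asr.py reads HUGGING_FACE_HUB_TOKEN at import time:

    import os
    os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", "hf_...")  # hypothetical token value

    from asr import load_model, inference

    processor, model = load_model()
    print(inference(processor, model, "sample.wav"))  # librosa.load accepts the path directly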
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ soundfile
+ transformers
+ torch
+ torchaudio
+ streamlit_webrtc
+ audio_recorder_streamlit
+ librosa
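
With these dependencies installed (pip install -r requirements.txt), the app starts locally with streamlit run app.py; on a Hugging Face Space, app.py is the default entry point.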