initial commit
- app.py +95 -0
- asr.py +44 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,95 @@
+import os
+import sys
+import datetime
+import streamlit as st
+from asr import load_model, inference
+from audio_recorder_streamlit import audio_recorder
+
+
+@st.cache_resource
+def load_asr_model():
+    # Cache the processor/model across Streamlit reruns so the
+    # checkpoint is only downloaded and loaded once.
+    return load_model()
+
+
+processor, asr_model = load_asr_model()
+
+
+def save_audio_file(audio_bytes, file_extension):
+    """
+    Save audio bytes to a file with the specified extension.
+
+    :param audio_bytes: Audio data in bytes
+    :param file_extension: The extension of the output audio file
+    :return: The name of the saved audio file
+    """
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    file_name = f"audio_{timestamp}.{file_extension}"
+
+    with open(file_name, "wb") as f:
+        f.write(audio_bytes)
+
+    return file_name
+
+
+def transcribe_audio(file_path):
+    """
+    Transcribe the audio file at the specified path.
+
+    :param file_path: The path of the audio file to transcribe
+    :return: The transcribed text
+    """
+    # inference() loads and resamples the audio itself, so the path is enough.
+    return inference(processor, asr_model, file_path)
+
+
+def main():
+    """
+    Record or upload audio, then transcribe it with the cached ASR model.
+    """
+    st.title("Anishinaabemowin Transcription")
+    tab1, tab2 = st.tabs(["Record Audio", "Upload Audio"])
+
+    fname = None
+
+    # Record Audio tab
+    with tab1:
+        audio_bytes = audio_recorder()
+        if audio_bytes:
+            st.audio(audio_bytes, format="audio/wav")
+            fname = save_audio_file(audio_bytes, "wav")
+
+    # Upload Audio tab
+    with tab2:
+        audio_file = st.file_uploader("Upload Audio", type=["wav"])
+        if audio_file:
+            file_extension = audio_file.type.split("/")[1]
+            fname = save_audio_file(audio_file.read(), file_extension)
+
+    # Transcribe button action
+    if st.button("Transcribe"):
+        if fname is None:
+            st.warning("Record or upload an audio file first.")
+            return
+
+        # Transcribe the audio file
+        transcript_text = transcribe_audio(fname)
+
+        # Display the transcript
+        st.header("Transcript")
+        st.write(transcript_text)
+
+        # Save the transcript to a text file
+        with open("transcript.txt", "w") as f:
+            f.write(transcript_text)
+
+        # Provide a download button for the transcript
+        st.download_button("Download Transcript", transcript_text)
+
+
+if __name__ == "__main__":
+    # Make sibling modules (e.g. asr.py) importable regardless of cwd.
+    working_dir = os.path.dirname(os.path.abspath(__file__))
+    sys.path.append(working_dir)
+
+    main()
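Side note (not part of the commit): Streamlit re-runs the whole script on every widget interaction, so save_audio_file() fires again, writing a fresh timestamped file, whenever the recorder or uploader still holds audio. A rerun-friendly variant would remember the saved path in st.session_state; a minimal sketch, with "last_audio" as an arbitrary key:

    # Hypothetical variant: save once, then reuse the remembered path.
    if audio_bytes and "last_audio" not in st.session_state:
        st.session_state["last_audio"] = save_audio_file(audio_bytes, "wav")
    fname = st.session_state.get("last_audio")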
asr.py
ADDED
@@ -0,0 +1,44 @@
+from transformers import Wav2Vec2ForCTC, AutoProcessor
+import torchaudio
+import torch
+import os
+import librosa
+
+hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
+
+
+def read_audio_data(file):
+    # Alternative torchaudio-based loader (currently unused by inference).
+    speech_array, sampling_rate = torchaudio.load(file, normalize=True)
+    return speech_array, sampling_rate
+
+
+def load_model():
+    model_id = "Lguyogiro/wav2vec2-large-mms-1b-nhi-adapterft-ilv_fold1"
+    target_lang = "nhi"
+    processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang, token=hf_token)
+    # ignore_mismatched_sizes is required because the adapter's vocabulary
+    # differs from the base MMS checkpoint's.
+    model = Wav2Vec2ForCTC.from_pretrained(
+        model_id,
+        target_lang=target_lang,
+        ignore_mismatched_sizes=True,
+        use_safetensors=True,
+        token=hf_token,
+    )
+    return processor, model
+
+
+def inference(processor, model, audio_path):
+    # Load and resample to the 16 kHz rate the wav2vec2 model expects.
+    audio, sampling_rate = librosa.load(audio_path, sr=16000)
+    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+
+    with torch.no_grad():
+        logits = model(inputs.input_values).logits
+
+    # Greedy CTC decoding: argmax per frame, then collapse repeats and blanks.
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+    return transcription
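Side note: the asr.py helpers can also be exercised outside Streamlit. A minimal sketch, assuming HUGGING_FACE_HUB_TOKEN is set (if the model repository needs it) and a hypothetical local recording sample.wav; any sample rate works, since inference() resamples to 16 kHz via librosa:

    from asr import load_model, inference

    # Load the MMS adapter checkpoint once, then transcribe a file by path.
    processor, model = load_model()
    print(inference(processor, model, "sample.wav"))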
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+soundfile
+streamlit
+transformers
+torch
+torchaudio
+streamlit_webrtc
+audio_recorder_streamlit
+librosa
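With these dependencies installed, the app starts with the usual "streamlit run app.py"; set HUGGING_FACE_HUB_TOKEN in the environment first if the model repository requires authentication.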