Kr08 committed on
Commit
6d2ca12
·
verified ·
1 Parent(s): f6ca2e3

Changed App UI, added transcription and translation buttons

Browse files
Files changed (1) hide show
  1. app.py +81 -32
app.py CHANGED
@@ -53,42 +53,91 @@ def detect_language(audio_file):
53
  print(f"Detected language: {max(probs[0], key=probs[0].get)}")
54
  return max(probs[0], key=probs[0].get)
55
 
56
- if submit_button and uploaded_files is not None:
57
- st.write("Files uploaded successfully!")
58
 
59
- for uploaded_file in uploaded_files:
60
- # Display file name and audio player
61
 
62
- st.write(f"**File name**: {uploaded_file.name}")
63
- st.audio(uploaded_file, format=uploaded_file.type)
64
 
65
- # Transcription section
66
- st.write("**Transcription**:")
67
 
68
- # Read the uploaded file data
69
- waveform, sampling_rate = ta.load(uploaded_file.getvalue())
70
- resampled_inp = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
71
 
72
- input_features = processor(resampled_inp[0], sampling_rate=16000, return_tensors='pt').input_features
73
 
74
- if task == "translate":
75
 
76
- # Detect Language
77
- lang = detect_language(input_features)
78
- with open('languages.pkl', 'rb') as f:
79
- lang_dict = pickle.load(f)
80
- detected_language = lang_dict[lang]
81
-
82
- # Set decoder & Predict translation
83
- forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language, task="translate")
84
- predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
85
- else:
86
- predicted_ids = model.generate(input_features)
87
- # decode token ids to text
88
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
89
- for i in range(len(transcription)):
90
- st.write(transcription[i])
91
- # print(waveform, sampling_rate)
92
- # Run transcription function and display
93
- # import pdb;pdb.set_trace()
94
- # st.write(audio_data.getvalue())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  print(f"Detected language: {max(probs[0], key=probs[0].get)}")
54
  return max(probs[0], key=probs[0].get)
55
 
56
+ # if submit_button and uploaded_files is not None:
57
+ # st.write("Files uploaded successfully!")
58
 
59
+ # for uploaded_file in uploaded_files:
60
+ # # Display file name and audio player
61
 
62
+ # st.write(f"**File name**: {uploaded_file.name}")
63
+ # st.audio(uploaded_file, format=uploaded_file.type)
64
 
65
+ # # Transcription section
66
+ # st.write("**Transcription**:")
67
 
68
+ # # Read the uploaded file data
69
+ # waveform, sampling_rate = ta.load(uploaded_file.getvalue())
70
+ # resampled_inp = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
71
 
72
+ # input_features = processor(resampled_inp[0], sampling_rate=16000, return_tensors='pt').input_features
73
 
74
+ # if task == "translate":
75
 
76
+ # # Detect Language
77
+ # lang = detect_language(input_features)
78
+ # with open('languages.pkl', 'rb') as f:
79
+ # lang_dict = pickle.load(f)
80
+ # detected_language = lang_dict[lang]
81
+
82
+ # # Set decoder & Predict translation
83
+ # forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language, task="translate")
84
+ # predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
85
+ # else:
86
+ # predicted_ids = model.generate(input_features)
87
+ # # decode token ids to text
88
+ # transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
89
+ # for i in range(len(transcription)):
90
+ # st.write(transcription[i])
91
+ # # print(waveform, sampling_rate)
92
+ # # Run transcription function and display
93
+ # # import pdb;pdb.set_trace()
94
+ # # st.write(audio_data.getvalue())
95
+
96
+
97
+
98
+ if submit_button and uploaded_files is not None:
99
+ # Initialize a list to store detected languages
100
+ detected_languages = []
101
+
102
+ for uploaded_file in uploaded_files:
103
+ # Read the uploaded file data
104
+ waveform, sampling_rate = ta.load(BytesIO(uploaded_file.read()))
105
+
106
+ # Resample if necessary
107
+ if sampling_rate != SAMPLING_RATE:
108
+ waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
109
+
110
+ # Detect language
111
+ detected_language = detect_language(waveform, SAMPLING_RATE)
112
+ detected_languages.append(detected_language)
113
+
114
+ # Display each uploaded file with its detected language and an audio player
115
+ for i, uploaded_file in enumerate(uploaded_files):
116
+ col1, col2 = st.columns([1, 3]) # Two columns, one for the player, one for the buttons
117
+
118
+ with col1:
119
+ st.write(f"**File name**: {uploaded_file.name}")
120
+ st.audio(BytesIO(uploaded_file.getvalue()), format=uploaded_file.type)
121
+ st.write(f"**Detected Language**: {detected_languages[i]}")
122
+
123
+ with col2:
124
+ # Add Transcription and Translation buttons
125
+ if st.button(f"Transcribe {uploaded_file.name}"):
126
+ # Transcription process
127
+ input_features = processor(waveform[0], sampling_rate=SAMPLING_RATE, return_tensors='pt').input_features
128
+ predicted_ids = model.generate(input_features)
129
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
130
+ for line in transcription:
131
+ st.write(line)
132
+
133
+ if st.button(f"Translate {uploaded_file.name}"):
134
+ # Translation process
135
+ with open('languages.pkl', 'rb') as f:
136
+ lang_dict = pickle.load(f)
137
+ detected_language_name = lang_dict[detected_languages[i]]
138
+
139
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language_name, task="translate")
140
+ predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
141
+ translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
142
+ for line in translation:
143
+ st.write(line)