Kr08 committed
Commit e368f8b · verified · 1 Parent(s): 03f8b40

Update app.py

Files changed (1)
  1. app.py +156 -68
app.py CHANGED
@@ -1,77 +1,165 @@
 
  import torch
- import pickle
- import streamlit as st

- from io import BytesIO
- from config import SAMPLING_RATE
- from model_utils import load_models
- from audio_processing import detect_language, process_long_audio, load_and_resample_audio

- # Clear GPU cache
- torch.cuda.empty_cache()

- # Load models at startup
- load_models()

- # Title of the app
- st.title("Audio Player with Live Transcription and Translation")

- # Sidebar for file uploader and submit button
- st.sidebar.header("Upload Audio Files")
- uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
- submit_button = st.sidebar.button("Submit")

- # Session state to hold data
- if 'audio_files' not in st.session_state:
-     st.session_state.audio_files = []
-     st.session_state.transcriptions = {}
-     st.session_state.translations = {}
-     st.session_state.detected_languages = []
-     st.session_state.waveforms = []

- # Process uploaded files
- if submit_button and uploaded_files is not None:
-     st.session_state.audio_files = uploaded_files
-     st.session_state.detected_languages = []
-     st.session_state.waveforms = []

-     for uploaded_file in uploaded_files:
-         waveform = load_and_resample_audio(BytesIO(uploaded_file.read()))
-         st.session_state.waveforms.append(waveform)
-         detected_language = detect_language(waveform)
-         st.session_state.detected_languages.append(detected_language)
-
- # Display uploaded files and options
- if 'audio_files' in st.session_state and st.session_state.audio_files:
-     for i, uploaded_file in enumerate(st.session_state.audio_files):
-         st.write(f"**File name**: {uploaded_file.name}")
-         st.audio(uploaded_file, format=uploaded_file.type)
-         st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")
-
-         col1, col2 = st.columns(2)
-
-         with col1:
-             if st.button(f"Transcribe {uploaded_file.name}"):
-                 with st.spinner("Transcribing..."):
-                     transcription = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE)
-                     st.session_state.transcriptions[i] = transcription
-
-             if st.session_state.transcriptions.get(i):
-                 st.write("**Transcription**:")
-                 st.text_area("", st.session_state.transcriptions[i], height=200, key=f"transcription_{i}")
-                 st.markdown(f'<div style="text-align: right;"><a href="data:text/plain;charset=UTF-8,{st.session_state.transcriptions[i]}" download="transcription_{uploaded_file.name}.txt">Download Transcription</a></div>', unsafe_allow_html=True)
-
-         with col2:
-             if st.button(f"Translate {uploaded_file.name}"):
-                 with st.spinner("Translating..."):
-                     with open('languages.pkl', 'rb') as f:
-                         lang_dict = pickle.load(f)
-                     detected_language_name = lang_dict[st.session_state.detected_languages[i]]
-
-                     translation = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE, task="translate",
-                                                      language=detected_language_name)
-                     st.session_state.translations[i] = translation
-             if st.session_state.translations.get(i):
-                 st.write("**Translation**:")
-                 st.text_area("", st.session_state.translations[i], height=200, key=f"translation_{i}")
-                 st.markdown(f'<div style="text-align: right;"><a href="data:text/plain;charset=UTF-8,{st.session_state.translations[i]}" download="translation_{uploaded_file.name}.txt">Download Translation</a></div>', unsafe_allow_html=True)

+ import gradio as gr
  import torch
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
+ import logging
+ import traceback
+ import sys
+ from audio_processing import AudioProcessor

+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     handlers=[logging.StreamHandler(sys.stdout)]
+ )
+ logger = logging.getLogger(__name__)

+ def load_qa_model():
+     """Load question-answering model"""
+     try:
+         qa_pipeline = pipeline(
+             "text-generation",
+             model="meta-llama/Meta-Llama-3-8B-Instruct",
+             model_kwargs={"torch_dtype": torch.bfloat16},
+             device_map="auto",
+         )
+         return qa_pipeline
+     except Exception as e:
+         logger.error(f"Failed to load Q&A model: {str(e)}")
+         return None

+ def load_summarization_model():
+     """Load summarization model"""
+     try:
+         summarizer = pipeline(
+             "summarization",
+             model="sshleifer/distilbart-cnn-12-6",
+             device=0 if torch.cuda.is_available() else -1
+         )
+         return summarizer
+     except Exception as e:
+         logger.error(f"Failed to load summarization model: {str(e)}")
+         return None

+ def process_audio(audio_file, translate=False):
+     """Process audio file"""
+     try:
+         processor = AudioProcessor()
+         language_segments, final_segments = processor.process_audio(audio_file, translate)
+
+         # Format output
+         transcription = ""
+         full_text = ""
+
+         # Add language detection information
+         for segment in language_segments:
+             transcription += f"Language: {segment['language']}\n"
+             transcription += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"
+
+         # Add transcription/translation information
+         transcription += "Transcription with language detection:\n\n"
+         for segment in final_segments:
+             transcription += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}):\n"
+             transcription += f"Original: {segment['text']}\n"
+             if translate and 'translated' in segment:
+                 transcription += f"Translated: {segment['translated']}\n"
+                 full_text += segment['translated'] + " "
+             else:
+                 full_text += segment['text'] + " "
+             transcription += "\n"
+
+         return transcription, full_text
+
+     except Exception as e:
+         logger.error(f"Audio processing failed: {str(e)}")
+         raise gr.Error(f"Processing failed: {str(e)}")

+ def summarize_text(text):
+     """Summarize text"""
+     try:
+         summarizer = load_summarization_model()
+         if summarizer is None:
+             return "Summarization model could not be loaded."
+
+         summary = summarizer(text, max_length=150, min_length=50, do_sample=False)[0]['summary_text']
+         return summary
+     except Exception as e:
+         logger.error(f"Summarization failed: {str(e)}")
+         return "Error occurred during summarization."

+ def answer_question(context, question):
+     """Answer questions about the text"""
+     try:
+         qa_pipeline = load_qa_model()
+         if qa_pipeline is None:
+             return "Q&A model could not be loaded."
+
+         messages = [
+             {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
+             {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
+         ]
+
+         response = qa_pipeline(messages, max_new_tokens=256)[0]['generated_text']
+         return response
+     except Exception as e:
+         logger.error(f"Q&A failed: {str(e)}")
+         return f"Error occurred during Q&A process: {str(e)}"

+ # Create Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown("# Multilingual Speech Processing with MMS")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(type="filepath")
+             translate_checkbox = gr.Checkbox(label="Enable Translation")
+             process_button = gr.Button("Process Audio")
+
+         with gr.Column():
+             transcription_output = gr.Textbox(label="Transcription/Translation", lines=10)
+             full_text_output = gr.Textbox(label="Full Text", lines=5)
+
+     with gr.Row():
+         with gr.Column():
+             summarize_button = gr.Button("Summarize")
+             summary_output = gr.Textbox(label="Summary", lines=3)
+
+         with gr.Column():
+             question_input = gr.Textbox(label="Ask a question about the transcription")
+             answer_button = gr.Button("Get Answer")
+             answer_output = gr.Textbox(label="Answer", lines=3)

+     # Set up event handlers
+     process_button.click(
+         process_audio,
+         inputs=[audio_input, translate_checkbox],
+         outputs=[transcription_output, full_text_output]
+     )
+
+     summarize_button.click(
+         summarize_text,
+         inputs=[full_text_output],
+         outputs=[summary_output]
+     )
+
+     answer_button.click(
+         answer_question,
+         inputs=[full_text_output, question_input],
+         outputs=[answer_output]
+     )
+
+     # Add system information
+     gr.Markdown(f"""
+     ## System Information
+     - Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
+     - CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}
+
+     ## Features
+     - Automatic language detection
+     - High-quality transcription using MMS
+     - Optional translation to English
+     - Text summarization
+     - Question answering
+     """)

+ if __name__ == "__main__":
+     iface.launch()
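
Note: the rewritten `app.py` imports an `AudioProcessor` class from `audio_processing`, which is not part of this commit. Below is a minimal, hypothetical stand-in inferred from the field accesses in `process_audio` above (`start`, `end`, `language`, `text`, and `translated` when translation is enabled); the class body and sample values are assumptions, not the real implementation.

```python
# Hypothetical stub of the audio_processing module app.py imports.
# The real AudioProcessor ships separately; this mirrors only the
# interface used by process_audio() so the Gradio UI can be exercised.

class AudioProcessor:
    def process_audio(self, audio_file, translate=False):
        # One dict per detected-language span (used for the header lines).
        language_segments = [
            {"language": "en", "start": 0.0, "end": 4.2},
        ]
        # Transcribed chunks; app.py reads 'translated' only when translate=True.
        final_segments = [
            {"start": 0.0, "end": 4.2, "language": "en", "text": "hello world"},
        ]
        if translate:
            for seg in final_segments:
                seg["translated"] = seg["text"]  # placeholder pass-through
        return language_segments, final_segments
```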
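
One caveat worth flagging on `answer_question`: when a transformers text-generation pipeline is called with chat-style messages, `generated_text` is typically the full message list rather than a plain string, so the value returned here may render as a Python list in the Answer box. A sketch of the usual extraction, assuming the standard chat output shape:

```python
# Sketch only: pull the assistant's reply out of chat-format pipeline output.
outputs = qa_pipeline(messages, max_new_tokens=256)
reply = outputs[0]["generated_text"][-1]  # last message: {"role": "assistant", "content": "..."}
response = reply["content"]
```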