awsaf49 committed on
Commit
efdd97c
·
verified ·
1 Parent(s): 3085c15

add: option for model and duration separately

Browse files
Files changed (1) hide show
  1. app.py +54 -29
app.py CHANGED
@@ -5,35 +5,42 @@ import numpy as np
5
  import gradio as gr
6
  from sonics import HFAudioClassifier
7
 
8
- # Model configurations
9
- MODEL_IDS = {
10
- "SpecTTTra-α (5s)": "awsaf49/sonics-spectttra-alpha-5s",
11
- "SpecTTTra-β (5s)": "awsaf49/sonics-spectttra-beta-5s",
12
- "SpecTTTra-γ (5s)": "awsaf49/sonics-spectttra-gamma-5s",
13
- "SpecTTTra-α (120s)": "awsaf49/sonics-spectttra-alpha-120s",
14
- "SpecTTTra-β (120s)": "awsaf49/sonics-spectttra-beta-120s",
15
- "SpecTTTra-γ (120s)": "awsaf49/sonics-spectttra-gamma-120s",
16
- }
 
 
 
 
 
 
 
17
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model_cache = {}
20
 
21
-
22
- def load_model(model_name):
23
  """Load model if not already cached"""
24
- if model_name not in model_cache:
25
- model_id = MODEL_IDS[model_name]
 
26
  model = HFAudioClassifier.from_pretrained(model_id)
27
  model = model.to(device)
28
  model.eval()
29
- model_cache[model_name] = model
30
- return model_cache[model_name]
31
 
32
 
33
- def process_audio(audio_path, model_name):
34
  """Process audio file and return prediction"""
35
  try:
36
- model = load_model(model_name)
37
  max_time = model.config.audio.max_time
38
 
39
  # Load and process audio
@@ -69,11 +76,11 @@ def process_audio(audio_path, model_name):
69
  return {"Error": str(e)}
70
 
71
 
72
- def predict(audio_file, model_name):
73
  """Gradio interface function"""
74
  if audio_file is None:
75
  return {"Message": "Please upload an audio file"}
76
- return process_audio(audio_file, model_name)
77
 
78
 
79
  # Updated CSS with better color scheme for resource links
@@ -146,6 +153,15 @@ css = """
146
  margin-top: 30px;
147
  padding: 15px;
148
  }
 
 
 
 
 
 
 
 
 
149
  """
150
 
151
  # Create Gradio interface
@@ -199,12 +215,21 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
199
  elem_id="audio_input"
200
  )
201
 
202
- model_dropdown = gr.Dropdown(
203
- choices=list(MODEL_IDS.keys()),
204
- value="SpecTTTra-γ (5s)",
205
- label="Select Model",
206
- elem_id="model_dropdown"
207
- )
 
 
 
 
 
 
 
 
 
208
 
209
  submit_btn = gr.Button(
210
  "✨ Analyze Audio",
@@ -240,10 +265,10 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
240
  with gr.Accordion("Example Audio Files", open=True):
241
  gr.Examples(
242
  examples=[
243
- ["example/real_song.mp3", "SpecTTTra-γ (5s)"],
244
- ["example/fake_song.mp3", "SpecTTTra-γ (5s)"],
245
  ],
246
- inputs=[audio_input, model_dropdown],
247
  outputs=[output],
248
  fn=predict,
249
  cache_examples=True,
@@ -260,7 +285,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
260
  )
261
 
262
  # Prediction handling
263
- submit_btn.click(fn=predict, inputs=[audio_input, model_dropdown], outputs=[output])
264
 
265
  if __name__ == "__main__":
266
  demo.launch()
 
5
  import gradio as gr
6
  from sonics import HFAudioClassifier
7
 
8
+ # Restructured model configurations for separate selectors
9
+ MODEL_TYPES = ["SpecTTTra-α", "SpecTTTra-β", "SpecTTTra-γ"]
10
+ DURATIONS = ["5s", "120s"]
11
+
12
+ # Mapping for model IDs
13
def get_model_id(model_type, duration):
    """Return the Hugging Face repo ID for a model type / duration pair.

    Args:
        model_type: Display name of the model family, e.g. ``"SpecTTTra-α"``.
        duration: Input-length variant, e.g. ``"5s"`` or ``"120s"``.

    Returns:
        The repository ID string, e.g. ``"awsaf49/sonics-spectttra-alpha-5s"``.

    Raises:
        KeyError: If the combination is not one of the known checkpoints.
            The message names the offending key and lists the supported
            ones, so it stays meaningful when shown as ``str(exc)``.
    """
    model_map = {
        "SpecTTTra-α-5s": "awsaf49/sonics-spectttra-alpha-5s",
        "SpecTTTra-β-5s": "awsaf49/sonics-spectttra-beta-5s",
        "SpecTTTra-γ-5s": "awsaf49/sonics-spectttra-gamma-5s",
        "SpecTTTra-α-120s": "awsaf49/sonics-spectttra-alpha-120s",
        "SpecTTTra-β-120s": "awsaf49/sonics-spectttra-beta-120s",
        "SpecTTTra-γ-120s": "awsaf49/sonics-spectttra-gamma-120s",
    }
    key = f"{model_type}-{duration}"
    try:
        return model_map[key]
    except KeyError:
        # A bare KeyError would stringify to just the quoted key; spell out
        # what went wrong and which combinations are valid instead.
        supported = ", ".join(model_map)
        raise KeyError(
            f"Unknown model/duration combination {key!r}; supported: {supported}"
        ) from None
24
 
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
  model_cache = {}
27
 
28
def load_model(model_type, duration):
    """Return the classifier for the given model type and duration.

    Models are kept in the module-level ``model_cache`` dict, keyed by
    ``"<model_type>-<duration>"``. On a cache miss the checkpoint is fetched
    with ``HFAudioClassifier.from_pretrained``, moved to ``device``, switched
    to eval mode, and stored before being returned.
    """
    cache_key = f"{model_type}-{duration}"

    # Fast path: this combination has already been loaded.
    if cache_key in model_cache:
        return model_cache[cache_key]

    repo_id = get_model_id(model_type, duration)
    classifier = HFAudioClassifier.from_pretrained(repo_id)
    classifier = classifier.to(device)
    classifier.eval()

    model_cache[cache_key] = classifier
    return model_cache[cache_key]
38
 
39
 
40
+ def process_audio(audio_path, model_type, duration):
41
  """Process audio file and return prediction"""
42
  try:
43
+ model = load_model(model_type, duration)
44
  max_time = model.config.audio.max_time
45
 
46
  # Load and process audio
 
76
  return {"Error": str(e)}
77
 
78
 
79
def predict(audio_file, model_type, duration):
    """Entry point wired to the Gradio UI.

    Forwards the uploaded file and the two selector values to
    ``process_audio``. When no file was provided, returns a dict with a
    ``"Message"`` entry asking for an upload instead.
    """
    if audio_file is not None:
        return process_audio(audio_file, model_type, duration)
    return {"Message": "Please upload an audio file"}
84
 
85
 
86
  # Updated CSS with better color scheme for resource links
 
153
  margin-top: 30px;
154
  padding: 15px;
155
  }
156
+
157
+ /* Selectors wrapper for side-by-side appearance */
158
+ .selectors-wrapper {
159
+ display: flex;
160
+ gap: 10px;
161
+ }
162
+ .selectors-wrapper > div {
163
+ flex: 1;
164
+ }
165
  """
166
 
167
  # Create Gradio interface
 
215
  elem_id="audio_input"
216
  )
217
 
218
+ # Add CSS class to create a wrapper for side-by-side dropdowns
219
+ with gr.Row(elem_classes="selectors-wrapper"):
220
+ model_dropdown = gr.Dropdown(
221
+ choices=MODEL_TYPES,
222
+ value="SpecTTTra-γ",
223
+ label="Select Model",
224
+ elem_id="model_dropdown"
225
+ )
226
+
227
+ duration_dropdown = gr.Dropdown(
228
+ choices=DURATIONS,
229
+ value="5s",
230
+ label="Select Duration",
231
+ elem_id="duration_dropdown"
232
+ )
233
 
234
  submit_btn = gr.Button(
235
  "✨ Analyze Audio",
 
265
  with gr.Accordion("Example Audio Files", open=True):
266
  gr.Examples(
267
  examples=[
268
+ ["example/real_song.mp3", "SpecTTTra-γ", "5s"],
269
+ ["example/fake_song.mp3", "SpecTTTra-γ", "5s"],
270
  ],
271
+ inputs=[audio_input, model_dropdown, duration_dropdown],
272
  outputs=[output],
273
  fn=predict,
274
  cache_examples=True,
 
285
  )
286
 
287
  # Prediction handling
288
+ submit_btn.click(fn=predict, inputs=[audio_input, model_dropdown, duration_dropdown], outputs=[output])
289
 
290
  if __name__ == "__main__":
291
  demo.launch()