owiedotch commited on
Commit
16120e1
·
verified ·
1 Parent(s): dc59c25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -5
app.py CHANGED
@@ -10,11 +10,12 @@ import tempfile
10
  import io
11
  import uuid
12
  import pickle
 
13
  from pathlib import Path
14
 
15
  # Initialize the model and ensure it's on the correct device
16
  def load_model():
17
- model = SemantiCodec(token_rate=100, semantic_vocab_size=32768) # 1.40 kbps
18
  if torch.cuda.is_available():
19
  # Move the model to CUDA
20
  model.to("cuda:0")
@@ -26,6 +27,9 @@ semanticodec = load_model()
26
  model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
27
  print(f"Model initialized on device: {model_device}")
28
 
 
 
 
29
  @spaces.GPU(duration=20)
30
  def encode_audio(audio_path):
31
  """Encode audio file to tokens and return them as a file"""
@@ -106,12 +110,11 @@ def decode_tokens(token_file):
106
 
107
  # Extract audio data - this should be a numpy array
108
  audio_data = waveform[0, 0] # Shape should be [time]
109
- sample_rate = 16000
110
 
111
  print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
112
 
113
  # Return in Gradio Audio compatible format: (sample_rate, audio_data)
114
- return (sample_rate, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
115
  except Exception as e:
116
  print(f"Decoding error: {str(e)}")
117
  return None, f"Error decoding tokens: {str(e)}"
@@ -155,16 +158,98 @@ def process_both(audio_path):
155
 
156
  # Extract audio data - this should be a numpy array
157
  audio_data = waveform[0, 0] # Shape should be [time]
158
- sample_rate = 16000
159
 
160
  print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
161
 
162
  # Return in Gradio Audio compatible format: (sample_rate, audio_data)
163
- return (sample_rate, audio_data), f"Encoded to {tokens.shape[1]} tokens\nDecoded {tokens.shape[1]} tokens to audio"
164
  except Exception as e:
165
  print(f"Processing error: {str(e)}")
166
  return None, f"Error processing audio: {str(e)}"
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # Create Gradio interface
169
  with gr.Blocks(title="Oterin Audio Codec") as demo:
170
  gr.Markdown("# Oterin Audio Codec")
@@ -186,6 +271,19 @@ with gr.Blocks(title="Oterin Audio Codec") as demo:
186
  decode_btn = gr.Button("Decode")
187
  decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  with gr.Tab("Both (Encode & Decode)"):
190
  with gr.Row():
191
  both_input = gr.Audio(type="filepath", label="Input Audio")
 
10
  import io
11
  import uuid
12
  import pickle
13
+ import time
14
  from pathlib import Path
15
 
16
  # Initialize the model and ensure it's on the correct device
17
  def load_model():
18
+ model = SemantiCodec(token_rate=100, semantic_vocab_size=16384) # 1.35 kbps
19
  if torch.cuda.is_available():
20
  # Move the model to CUDA
21
  model.to("cuda:0")
 
27
  model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
28
  print(f"Model initialized on device: {model_device}")
29
 
30
+ # Define sample rate as a constant
31
+ SAMPLE_RATE = 32000
32
+
33
  @spaces.GPU(duration=20)
34
  def encode_audio(audio_path):
35
  """Encode audio file to tokens and return them as a file"""
 
110
 
111
  # Extract audio data - this should be a numpy array
112
  audio_data = waveform[0, 0] # Shape should be [time]
 
113
 
114
  print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
115
 
116
  # Return in Gradio Audio compatible format: (sample_rate, audio_data)
117
+ return (SAMPLE_RATE, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
118
  except Exception as e:
119
  print(f"Decoding error: {str(e)}")
120
  return None, f"Error decoding tokens: {str(e)}"
 
158
 
159
  # Extract audio data - this should be a numpy array
160
  audio_data = waveform[0, 0] # Shape should be [time]
 
161
 
162
  print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")
163
 
164
  # Return in Gradio Audio compatible format: (sample_rate, audio_data)
165
+ return (SAMPLE_RATE, audio_data), f"Encoded to {tokens.shape[1]} tokens\nDecoded {tokens.shape[1]} tokens to audio"
166
  except Exception as e:
167
  print(f"Processing error: {str(e)}")
168
  return None, f"Error processing audio: {str(e)}"
169
 
170
+ @spaces.GPU(duration=360)
171
+ def stream_decode_tokens(token_file):
172
+ """Decode tokens to audio in streaming chunks"""
173
+ # Ensure the file exists and has content
174
+ if not token_file or not os.path.exists(token_file):
175
+ yield None, "Error: Empty or missing token file"
176
+ return
177
+
178
+ try:
179
+ # Load tokens using pickle instead of numpy load
180
+ with open(token_file, "rb") as f:
181
+ token_data = pickle.load(f)
182
+
183
+ tokens = token_data['tokens']
184
+ intended_device = token_data.get('device', model_device)
185
+ print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")
186
+
187
+ # If tokens are too small, decode all at once
188
+ if tokens.shape[1] < 500:
189
+ # Convert to torch tensor with Long dtype for embedding
190
+ tokens_tensor = torch.tensor(tokens, dtype=torch.long)
191
+ tokens_tensor = tokens_tensor.to(model_device)
192
+
193
+ # Decode the tokens
194
+ waveform = semanticodec.decode(tokens_tensor)
195
+ if isinstance(waveform, torch.Tensor):
196
+ waveform = waveform.cpu().numpy()
197
+
198
+ audio_data = waveform[0, 0]
199
+ yield (SAMPLE_RATE, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
200
+ return
201
+
202
+ # Split tokens into chunks for streaming
203
+ chunk_size = 500 # Number of tokens per chunk
204
+ num_chunks = (tokens.shape[1] + chunk_size - 1) // chunk_size # Ceiling division
205
+
206
+ # First status update
207
+ yield None, f"Starting decoding of {tokens.shape[1]} tokens in {num_chunks} chunks..."
208
+
209
+ all_audio_chunks = []
210
+
211
+ for i in range(num_chunks):
212
+ start_idx = i * chunk_size
213
+ end_idx = min((i + 1) * chunk_size, tokens.shape[1])
214
+
215
+ print(f"Decoding chunk {i+1}/{num_chunks}, tokens {start_idx} to {end_idx}")
216
+
217
+ # Extract chunk of tokens
218
+ token_chunk = tokens[:, start_idx:end_idx, :]
219
+
220
+ # Convert to torch tensor with Long dtype
221
+ tokens_tensor = torch.tensor(token_chunk, dtype=torch.long)
222
+ tokens_tensor = tokens_tensor.to(model_device)
223
+
224
+ # Ensure model is on the expected device
225
+ semanticodec.to(model_device)
226
+
227
+ # Decode the tokens
228
+ waveform = semanticodec.decode(tokens_tensor)
229
+ if isinstance(waveform, torch.Tensor):
230
+ waveform = waveform.cpu().numpy()
231
+
232
+ # Extract audio data
233
+ audio_chunk = waveform[0, 0]
234
+ all_audio_chunks.append(audio_chunk)
235
+
236
+ # Combine all chunks we have so far
237
+ combined_audio = np.concatenate(all_audio_chunks)
238
+
239
+ # Yield the combined audio for streaming playback
240
+ yield (SAMPLE_RATE, combined_audio), f"Decoded chunk {i+1}/{num_chunks} ({end_idx}/{tokens.shape[1]} tokens)"
241
+
242
+ # Small delay to allow Gradio to update UI
243
+ time.sleep(0.1)
244
+
245
+ # Final complete audio
246
+ combined_audio = np.concatenate(all_audio_chunks)
247
+ yield (SAMPLE_RATE, combined_audio), f"Completed decoding all {tokens.shape[1]} tokens"
248
+
249
+ except Exception as e:
250
+ print(f"Streaming decode error: {str(e)}")
251
+ yield None, f"Error decoding tokens: {str(e)}"
252
+
253
  # Create Gradio interface
254
  with gr.Blocks(title="Oterin Audio Codec") as demo:
255
  gr.Markdown("# Oterin Audio Codec")
 
271
  decode_btn = gr.Button("Decode")
272
  decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
273
 
274
+ with gr.Tab("Stream Decode (Listen while decoding)"):
275
+ with gr.Row():
276
+ stream_decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
277
+ stream_decode_output = gr.Audio(label="Streaming Audio Output")
278
+ stream_decode_status = gr.Textbox(label="Status")
279
+ stream_decode_btn = gr.Button("Start Streaming Decode")
280
+ stream_decode_btn.click(
281
+ stream_decode_tokens,
282
+ inputs=stream_decode_input,
283
+ outputs=[stream_decode_output, stream_decode_status],
284
+ show_progress=True
285
+ )
286
+
287
  with gr.Tab("Both (Encode & Decode)"):
288
  with gr.Row():
289
  both_input = gr.Audio(type="filepath", label="Input Audio")