Update app.py
app.py CHANGED
@@ -158,6 +158,7 @@ def stage2_generate(model_stage2, prompt, batch_size=16):
     Given a prompt (a numpy array of raw codec ids), upsample using the Stage2 model.
     """
     # Unflatten prompt: assume prompt shape (1, T) and then reformat.
+    print(f"stage2_generate: received prompt with shape: {prompt.shape}")
     codec_ids = codectool.unflatten(prompt, n_quantizer=1)
     codec_ids = codectool.offset_tok_ids(
         codec_ids,
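For orientation, the unflatten call in this hunk reshapes the flat (1, T) id stream into one row per quantizer before token offsets are applied. The helper below is a hypothetical numpy re-implementation for illustration only, assuming the codebooks are interleaved along the time axis; the real codectool.unflatten may differ.

```python
import numpy as np

def unflatten_sketch(prompt: np.ndarray, n_quantizer: int = 1) -> np.ndarray:
    # Illustrative stand-in for codectool.unflatten (assumed behavior):
    # a (1, T) stream interleaving n_quantizer codebooks,
    # [q0_t0, q1_t0, q0_t1, q1_t1, ...], becomes (n_quantizer, T // n_quantizer).
    assert prompt.ndim == 2 and prompt.shape[0] == 1
    assert prompt.shape[1] % n_quantizer == 0
    return prompt.reshape(-1, n_quantizer).T

ids = np.arange(8)[np.newaxis, :]          # shape (1, 8)
print(unflatten_sketch(ids, n_quantizer=2))
# [[0 2 4 6]
#  [1 3 5 7]]
```

With n_quantizer=1, as in the call above, this reduces to a no-op reshape, which is why the (1, T) shape assumption in the comment matters.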
@@ -238,10 +239,20 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
             print(f"{output_filename} already processed.")
             stage2_result.append(output_filename)
             continue
+
         prompt = np.load(path).astype(np.int32)
-        #
+        # Ensure prompt is 2D.
+        if prompt.ndim == 1:
+            prompt = prompt[np.newaxis, :]
+        print(f"Loaded prompt from {path} with shape: {prompt.shape}")
+
+        # Compute output duration: tokens per second assumed to be 50, only full 6-second segments.
         output_duration = (prompt.shape[-1] // 50) // 6 * 6
+        if output_duration == 0:
+            raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
         num_batch = output_duration // 6
+
+        # Process prompt in batches
         if num_batch <= batch_size:
             output = stage2_generate(model_stage2, prompt[:, :output_duration*50], batch_size=num_batch)
         else:
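The duration math added here, in isolation: at the assumed 50 codec tokens per second, the token count is floored to whole seconds and then to a multiple of 6 seconds, so any prompt shorter than 300 tokens now raises instead of silently producing an empty slice downstream. A minimal sketch of that calculation (helper name hypothetical):

```python
def output_duration_seconds(n_tokens: int, tokens_per_sec: int = 50) -> int:
    # Mirrors (prompt.shape[-1] // 50) // 6 * 6 plus the new zero-duration guard.
    duration = (n_tokens // tokens_per_sec) // 6 * 6
    if duration == 0:
        raise ValueError(f"Prompt too short: {n_tokens} tokens is under 6 seconds")
    return duration

print(output_duration_seconds(1000))   # 18: 20s of tokens floors to 18s (three full 6s batches)
print(output_duration_seconds(1499))   # 24: 29.98s floors to 24s
output_duration_seconds(299)           # raises ValueError
```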
@@ -251,16 +262,25 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
                 start_idx = seg * batch_size * 300
                 end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
                 current_batch = batch_size if (seg != num_segments - 1 or num_batch % batch_size == 0) else num_batch % batch_size
-                segment = stage2_generate(model_stage2, prompt[:, start_idx:end_idx], batch_size=current_batch)
+                segment_prompt = prompt[:, start_idx:end_idx]
+                if segment_prompt.shape[-1] == 0:
+                    print(f"Warning: empty segment detected for seg {seg}, start {start_idx}, end {end_idx}. Skipping this segment.")
+                    continue
+                segment = stage2_generate(model_stage2, segment_prompt, batch_size=current_batch)
                 segments.append(segment)
+            if len(segments) == 0:
+                raise ValueError(f"No valid segments produced for {path}.")
             output = np.concatenate(segments, axis=0)
+
         # Process any remaining tokens if prompt length not fully used.
         if output_duration * 50 != prompt.shape[-1]:
             ending = stage2_generate(model_stage2, prompt[:, output_duration * 50:], batch_size=1)
             output = np.concatenate([output, ending], axis=0)
-
+
+        # Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
         output = codectool_stage2.ids2npy(output)
-
+
+        # Fix any invalid codes
         fixed_output = copy.deepcopy(output)
         for i, line in enumerate(output):
             for j, element in enumerate(line):
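Since each 6-second batch element is 300 tokens (6 s at 50 tokens/s), a full segment spans batch_size * 300 tokens, and only the last segment may carry fewer than batch_size elements. The sketch below reproduces the bookkeeping from this hunk so the indices can be checked without the model; num_segments is computed outside the shown context, so the ceiling division here is an assumption.

```python
def plan_segments(num_batch: int, batch_size: int, total_tokens: int):
    # Assumed: num_segments is the ceiling of num_batch / batch_size.
    num_segments = (num_batch + batch_size - 1) // batch_size
    plan = []
    for seg in range(num_segments):
        start_idx = seg * batch_size * 300
        end_idx = min((seg + 1) * batch_size * 300, total_tokens)
        # The last segment takes the remainder of num_batch, if any.
        current_batch = batch_size if (seg != num_segments - 1 or num_batch % batch_size == 0) else num_batch % batch_size
        plan.append((start_idx, end_idx, current_batch))
    return plan

# 7 six-second batches (42s, 2100 tokens) processed 4 at a time:
print(plan_segments(7, 4, 2100))
# [(0, 1200, 4), (1200, 2100, 3)]
```

For consistent inputs the plan never yields an empty slice, since start_idx stays below num_batch * 300; the new empty-segment guard in the diff protects against prompts whose length disagrees with the precomputed duration.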
@@ -268,6 +288,7 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
                     counter = Counter(line)
                     most_common = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
                     fixed_output[i, j] = most_common
+
         np.save(output_filename, fixed_output)
         stage2_result.append(output_filename)
     return stage2_result
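The invalid-code fix replaces each flagged element with the most frequent value in its row. The condition that flags an element sits just above this hunk's context, so the range check below is an assumption; everything else mirrors the Counter logic shown.

```python
import copy
from collections import Counter

import numpy as np

def fix_invalid_codes(output: np.ndarray, lo: int = 0, hi: int = 1023) -> np.ndarray:
    # Replace out-of-range ids with the most common id in the same row.
    # The [lo, hi] validity range is an assumption; the real check is
    # outside the shown hunk context.
    fixed_output = copy.deepcopy(output)
    for i, line in enumerate(output):
        for j, element in enumerate(line):
            if element < lo or element > hi:
                counter = Counter(line)
                most_common = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
                fixed_output[i, j] = most_common
    return fixed_output

codes = np.array([[5, 5, 7, -1], [9, 2000, 9, 3]])
print(fix_invalid_codes(codes))
# [[5 5 7 5]
#  [9 9 9 3]]
```

Note the Counter is rebuilt for every flagged element and votes over the raw row, invalid values included; hoisting it out of the inner loop would be a cheap optimization when rows are long.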