Update app.py
Browse files
app.py
CHANGED
@@ -225,7 +225,7 @@ def generate_music(
|
|
225 |
print(
|
226 |
f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
|
227 |
input_ids = input_ids[:, -(max_context):]
|
228 |
-
with torch.inference_mode(), torch.autocast(device_type=
|
229 |
output_seq = model.generate(
|
230 |
input_ids=input_ids,
|
231 |
max_new_tokens=max_new_tokens,
|
|
|
225 |
print(
|
226 |
f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
|
227 |
input_ids = input_ids[:, -(max_context):]
|
228 |
+
with torch.inference_mode(), torch.autocast(device_type=device, dtype=torch.float16):
|
229 |
output_seq = model.generate(
|
230 |
input_ids=input_ids,
|
231 |
max_new_tokens=max_new_tokens,
|