Update app.py
Browse files
app.py
CHANGED
@@ -55,6 +55,7 @@ def Generate_Chords(input_midi, input_num_prime_chords, input_num_gen_chords, in
|
|
55 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
|
56 |
|
57 |
model.to(DEVICE)
|
|
|
58 |
print('Done!')
|
59 |
print('=' * 70)
|
60 |
|
@@ -63,10 +64,12 @@ def Generate_Chords(input_midi, input_num_prime_chords, input_num_gen_chords, in
|
|
63 |
model.load_state_dict(
|
64 |
torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
|
65 |
map_location=DEVICE))
|
|
|
|
|
|
|
66 |
print('Done!')
|
67 |
print('=' * 70)
|
68 |
|
69 |
-
model.eval()
|
70 |
|
71 |
if DEVICE == 'cpu':
|
72 |
dtype = torch.bfloat16
|
|
|
55 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
|
56 |
|
57 |
model.to(DEVICE)
|
58 |
+
|
59 |
print('Done!')
|
60 |
print('=' * 70)
|
61 |
|
|
|
64 |
model.load_state_dict(
|
65 |
torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
|
66 |
map_location=DEVICE))
|
67 |
+
|
68 |
+
model.eval()
|
69 |
+
|
70 |
print('Done!')
|
71 |
print('=' * 70)
|
72 |
|
|
|
73 |
|
74 |
if DEVICE == 'cpu':
|
75 |
dtype = torch.bfloat16
|