Update app.py
Browse files
app.py
CHANGED
@@ -31,19 +31,24 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
31 |
|
32 |
print('Loading model...')
|
33 |
|
34 |
-
SEQ_LEN = 8192
|
35 |
-
PAD_IDX =
|
36 |
-
DEVICE = 'cuda' # '
|
37 |
|
38 |
# instantiate the model
|
39 |
|
40 |
model = TransformerWrapper(
|
41 |
num_tokens = PAD_IDX+1,
|
42 |
max_seq_len = SEQ_LEN,
|
43 |
-
attn_layers = Decoder(dim = 2048,
|
|
|
|
|
|
|
|
|
|
|
44 |
)
|
45 |
|
46 |
-
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
47 |
|
48 |
model.to(DEVICE)
|
49 |
print('=' * 70)
|
@@ -51,7 +56,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
51 |
print('Loading model checkpoint...')
|
52 |
|
53 |
model.load_state_dict(
|
54 |
-
torch.load('
|
55 |
map_location=DEVICE))
|
56 |
print('=' * 70)
|
57 |
|
@@ -62,7 +67,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
62 |
else:
|
63 |
dtype = torch.float16
|
64 |
|
65 |
-
ctx = torch.amp.autocast(device_type=
|
66 |
|
67 |
print('Done!')
|
68 |
print('=' * 70)
|
|
|
31 |
|
32 |
print('Loading model...')
|
33 |
|
34 |
+
SEQ_LEN = 8192
|
35 |
+
PAD_IDX = 2239
|
36 |
+
DEVICE = 'cuda' # 'cpu'
|
37 |
|
38 |
# instantiate the model
|
39 |
|
40 |
model = TransformerWrapper(
|
41 |
num_tokens = PAD_IDX+1,
|
42 |
max_seq_len = SEQ_LEN,
|
43 |
+
attn_layers = Decoder(dim = 2048,
|
44 |
+
depth = 8,
|
45 |
+
heads = 32,
|
46 |
+
rotary_pos_emb = True,
|
47 |
+
attn_flash = True
|
48 |
+
)
|
49 |
)
|
50 |
|
51 |
+
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
|
52 |
|
53 |
model.to(DEVICE)
|
54 |
print('=' * 70)
|
|
|
56 |
print('Loading model checkpoint...')
|
57 |
|
58 |
model.load_state_dict(
|
59 |
+
torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
|
60 |
map_location=DEVICE))
|
61 |
print('=' * 70)
|
62 |
|
|
|
67 |
else:
|
68 |
dtype = torch.float16
|
69 |
|
70 |
+
ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
|
71 |
|
72 |
print('Done!')
|
73 |
print('=' * 70)
|