Update app.py
Browse files
app.py
CHANGED
@@ -28,8 +28,9 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
28 |
print('=' * 70)
|
29 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
30 |
start_time = reqtime.time()
|
31 |
-
|
32 |
-
print('
|
|
|
33 |
|
34 |
SEQ_LEN = 8192
|
35 |
PAD_IDX = 2239
|
@@ -51,6 +52,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
51 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
|
52 |
|
53 |
model.to(DEVICE)
|
|
|
54 |
print('=' * 70)
|
55 |
|
56 |
print('Loading model checkpoint...')
|
@@ -58,6 +60,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
58 |
model.load_state_dict(
|
59 |
torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
|
60 |
map_location=DEVICE))
|
|
|
61 |
print('=' * 70)
|
62 |
|
63 |
model.eval()
|
@@ -65,9 +68,9 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
65 |
if DEVICE == 'cpu':
|
66 |
dtype = torch.bfloat16
|
67 |
else:
|
68 |
-
dtype = torch.
|
69 |
|
70 |
-
ctx = torch.amp.autocast(device_type=
|
71 |
|
72 |
print('Done!')
|
73 |
print('=' * 70)
|
@@ -77,12 +80,12 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
|
|
77 |
|
78 |
input_num_tokens = max(4, min(128, input_num_tokens))
|
79 |
|
80 |
-
print('
|
81 |
print('Input file name:', fn)
|
82 |
print('Req num toks:', input_num_tokens)
|
83 |
print('Conditioning type:', input_conditioning_type)
|
84 |
print('Strip notes:', input_strip_notes)
|
85 |
-
print('
|
86 |
|
87 |
#===============================================================================
|
88 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
|
|
|
28 |
print('=' * 70)
|
29 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
30 |
start_time = reqtime.time()
|
31 |
+
|
32 |
+
print('=' * 70)
|
33 |
+
print('Instantiating the model...')
|
34 |
|
35 |
SEQ_LEN = 8192
|
36 |
PAD_IDX = 2239
|
|
|
52 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
|
53 |
|
54 |
model.to(DEVICE)
|
55 |
+
print('Done!')
|
56 |
print('=' * 70)
|
57 |
|
58 |
print('Loading model checkpoint...')
|
|
|
60 |
model.load_state_dict(
|
61 |
torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
|
62 |
map_location=DEVICE))
|
63 |
+
print('Done!')
|
64 |
print('=' * 70)
|
65 |
|
66 |
model.eval()
|
|
|
68 |
if DEVICE == 'cpu':
|
69 |
dtype = torch.bfloat16
|
70 |
else:
|
71 |
+
dtype = torch.bfloat16
|
72 |
|
73 |
+
ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
|
74 |
|
75 |
print('Done!')
|
76 |
print('=' * 70)
|
|
|
80 |
|
81 |
input_num_tokens = max(4, min(128, input_num_tokens))
|
82 |
|
83 |
+
print('=' * 70)
|
84 |
print('Input file name:', fn)
|
85 |
print('Req num toks:', input_num_tokens)
|
86 |
print('Conditioning type:', input_conditioning_type)
|
87 |
print('Strip notes:', input_strip_notes)
|
88 |
+
print('=' * 70)
|
89 |
|
90 |
#===============================================================================
|
91 |
raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
|