Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
-
|
2 |
-
# https://huggingface.co/spaces/asigalov61/
|
3 |
-
|
4 |
|
5 |
import time as reqtime
|
6 |
import datetime
|
7 |
from pytz import timezone
|
8 |
|
|
|
9 |
import re
|
10 |
import tqdm
|
11 |
|
@@ -21,29 +22,113 @@ import matplotlib.pyplot as plt
|
|
21 |
|
22 |
#=====================================================================================
|
23 |
|
24 |
-
|
|
|
|
|
25 |
|
26 |
-
|
27 |
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
tokens.extend([385])
|
32 |
|
33 |
-
|
34 |
-
|
35 |
|
36 |
-
|
37 |
-
tokens.extend([387])
|
38 |
|
39 |
-
|
|
|
40 |
|
41 |
#====================================================================================
|
42 |
|
43 |
-
def
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
print('=' * 70)
|
49 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
|
|
1 |
+
#==================================================================
|
2 |
+
# https://huggingface.co/spaces/asigalov61/Popular-Hook-Transformer
|
3 |
+
#==================================================================
|
4 |
|
5 |
import time as reqtime
|
6 |
import datetime
|
7 |
from pytz import timezone
|
8 |
|
9 |
+
import statistics
|
10 |
import re
|
11 |
import tqdm
|
12 |
|
|
|
22 |
|
23 |
#=====================================================================================
|
24 |
|
25 |
+
print('=' * 70)
|
26 |
+
print('Popular Hook Transformer')
|
27 |
+
print('=' * 70)
|
28 |
|
29 |
+
print('Loading Popular Hook Transformer training data...')
|
30 |
|
31 |
+
#====================================================================================
|
32 |
+
|
33 |
+
# Fixed length of every training sequence: shorter sequences are right-padded
# up to this length, and it is also passed to the model as max_seq_len.
SEQ_LEN = 512
|
34 |
+
# Token id used for padding; the model vocabulary size is PAD_IDX + 1 and the
# autoregressive wrapper is told to ignore this index.
PAD_IDX = 918
|
35 |
+
# Device string the model is moved to and that autocast is configured for.
DEVICE = 'cpu'
|
36 |
+
|
37 |
+
#====================================================================================
|
38 |
+
|
39 |
+
def str_strip(string):
    """Reduce `string` to ASCII letters and hyphens, minus trailing hyphens.

    Every run of characters other than ``A-Z``, ``a-z`` and ``-`` is removed,
    then any hyphens left at the end of the result are stripped.
    """
    letters_and_hyphens = re.sub(r'[^A-Za-z-]+', '', string)
    return letters_and_hyphens.rstrip('-')
|
41 |
+
|
42 |
+
def mode_time(seq):
    """Return the most common value of `seq` lying strictly between 0 and 128.

    Values outside that open interval are ignored (presumably they are not
    time tokens -- confirm against the tokenizer). `statistics.mode` raises
    `statistics.StatisticsError` when no value qualifies.
    """
    in_range = [tok for tok in seq if tok > 0 and tok < 128]
    return statistics.mode(in_range)
|
44 |
+
|
45 |
+
def mode_dur(seq):
    """Return the most common value of `seq` in (128, 256), shifted down by 128.

    Only values strictly between 128 and 256 are considered (presumably the
    duration tokens -- confirm against the tokenizer); each is reduced by 128
    before the mode is taken. `statistics.mode` raises
    `statistics.StatisticsError` when no value qualifies.
    """
    shifted = [tok - 128 for tok in seq if tok > 128 and tok < 256]
    return statistics.mode(shifted)
|
47 |
+
|
48 |
+
def mode_pitch(seq):
    """Return the most common ``value % 128`` among values of `seq` in (256, 512).

    Only values strictly between 256 and 512 are considered (presumably the
    pitch tokens -- confirm against the tokenizer); each is reduced modulo 128
    before the mode is taken. `statistics.mode` raises
    `statistics.StatisticsError` when no value qualifies.
    """
    reduced = [tok % 128 for tok in seq if tok > 256 and tok < 512]
    return statistics.mode(reduced)
|
50 |
+
|
51 |
+
# Sorted list of the distinct, cleaned part labels found at index 2 of every
# entry; a label's position in this list is its part token number.
# (str_strip already removes trailing hyphens, so no extra rstrip is needed.)
parts_dict = sorted({str_strip(s[2]) for s in melody_chords_f})

train_data = []

for m in tqdm.tqdm(melody_chords_f):

    # Only entries whose token list (index 5) has 65..505 tokens are used, so
    # that the 5 header tokens + score + 1 trailing token fit inside SEQ_LEN.
    if not 64 < len(m[5]) < 506:
        continue

    # The part token does not depend on the transposition, so resolve it once
    # per entry instead of once per transposed copy.
    part_tok = parts_dict.index(str_strip(m[2]))

    # One copy of the entry per transposition value in -3..2.
    for tv in range(-3, 3):

        # Shift every token strictly between 256 and 512 by tv; all other
        # tokens are kept as they are.
        # NOTE(review): a shifted token can leave the (256, 512) range
        # (e.g. 257 - 3 = 254) and would then be counted by mode_dur instead
        # of mode_pitch -- confirm this cannot happen with real data.
        score = [t+tv if 256 < t < 512 else t for t in m[5]]

        # Header: a leading 916 token followed by the part token and the
        # most common time / duration / pitch values, each moved into its own
        # token range by a fixed offset.
        # NOTE(review): `score` is already transposed, so adding tv again to
        # the pitch header token looks like a double shift. Left unchanged
        # because the trained checkpoint was presumably built with it.
        seq = [916,
               part_tok+512,
               mode_time(score)+532,
               mode_dur(score)+660,
               mode_pitch(score)+tv+788]

        seq += score

        # Trailing 917 token closes the sequence.
        seq.append(917)

        # Right-pad to the fixed training length.
        seq += [PAD_IDX] * (SEQ_LEN - len(seq))

        train_data.append(seq)
|
75 |
+
|
76 |
+
#====================================================================================
|
77 |
+
|
78 |
+
print('Done!')
print('=' * 70)
# Sanity check: the longest and the shortest training sequence must have the
# same length, i.e. every sequence was padded to one common size.
print('All data is good:', max(map(len, train_data)) == min(map(len, train_data)))
print('=' * 70)
print('Randomizing training data...')
# In-place shuffle of the training sequences.
random.shuffle(train_data)
print('Done!')
print('=' * 70)
print('Total length of training data:', len(train_data))
print('=' * 70)
|
88 |
+
|
89 |
+
#====================================================================================
|
90 |
+
|
91 |
+
print('Loading Popular Hook Transformer pre-trained model...')
print('=' * 70)

print('Instantiating model...')

# Decoder-only transformer whose vocabulary covers every token id up to and
# including PAD_IDX and whose context length equals the training length.
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024,
                          depth = 4,
                          heads = 32,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
    )

# Padding tokens are both ignored by the wrapper and used as its pad value.
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

model_path = 'Popular_Hook_Transformer_Small_Trained_Model_10869_steps_0.2308_loss_0.9252_acc.pth'

# map_location remaps the stored tensors onto DEVICE while loading; without
# it a checkpoint saved from a GPU cannot be deserialized on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))

print('=' * 70)

model.to(DEVICE)
model.eval()

# Autocast context (bfloat16 on DEVICE) created here for use elsewhere in the file.
ctx = torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)

print('Done!')
print('=' * 70)
|
125 |
|
126 |
#====================================================================================
|
127 |
|
128 |
+
def Generate_POP_Section(input_parsons_code,
|
129 |
+
input_first_note_duration,
|
130 |
+
iinput_first_note_MIDI_pitch
|
131 |
+
):
|
132 |
|
133 |
print('=' * 70)
|
134 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|