asigalov61 commited on
Commit
07c6c95
·
verified ·
1 Parent(s): 55d1ad7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -17
app.py CHANGED
@@ -1,11 +1,12 @@
1
- #=========================================================================
2
- # https://huggingface.co/spaces/asigalov61/Parsons-Code-Melody-Transformer
3
- #=========================================================================
4
 
5
  import time as reqtime
6
  import datetime
7
  from pytz import timezone
8
 
 
9
  import re
10
  import tqdm
11
 
@@ -21,29 +22,113 @@ import matplotlib.pyplot as plt
21
 
22
  #=====================================================================================
23
 
24
- def parsons_code_to_tokens(parsons_code_str):
 
 
25
 
26
- tokens = [388]
27
 
28
- for chr in parsons_code_str[1:]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- if chr == 'D':
31
- tokens.extend([385])
32
 
33
- elif chr == 'R':
34
- tokens.extend([386])
35
 
36
- elif chr == 'U':
37
- tokens.extend([387])
38
 
39
- return tokens
 
40
 
41
  #====================================================================================
42
 
43
- def Generate_Melody(input_parsons_code,
44
- input_first_note_duration,
45
- iinput_first_note_MIDI_pitch
46
- ):
47
 
48
  print('=' * 70)
49
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
 
1
+ #==================================================================
2
+ # https://huggingface.co/spaces/asigalov61/Popular-Hook-Transformer
3
+ #==================================================================
4
 
5
  import time as reqtime
6
  import datetime
7
  from pytz import timezone
8
 
9
+ import statistics
10
  import re
11
  import tqdm
12
 
 
22
 
23
  #=====================================================================================
24
 
25
+ print('=' * 70)
26
+ print('Popular Hook Transformer')
27
+ print('=' * 70)
28
 
29
+ print('Loading Popular Hook Transformer training data...')
30
 
31
+ #====================================================================================
32
+
33
+ SEQ_LEN = 512
34
+ PAD_IDX = 918
35
+ DEVICE = 'cpu'
36
+
37
+ #====================================================================================
38
+
39
+ def str_strip(string):
40
+ return re.sub(r'[^A-Za-z-]+', '', string).rstrip('-')
41
+
42
+ def mode_time(seq):
43
+ return statistics.mode([t for t in seq if 0 < t < 128])
44
+
45
+ def mode_dur(seq):
46
+ return statistics.mode([t-128 for t in seq if 128 < t < 256])
47
+
48
+ def mode_pitch(seq):
49
+ return statistics.mode([t % 128 for t in seq if 256 < t < 512])
50
+
51
+ parts_dict = sorted(set([str_strip(s[2]).rstrip('-') for s in melody_chords_f]))
52
+
53
+ train_data = []
54
+
55
+ for m in tqdm.tqdm(melody_chords_f):
56
+
57
+ if 64 < len(m[5]) < 506:
58
+
59
+ for tv in range(-3, 3):
60
+
61
+ part = str_strip(m[2])
62
+ part_tok = parts_dict.index(part)
63
+
64
+ score = [t+tv if 256 < t < 512 else t for t in m[5]]
65
+
66
+ seq = [916] + [part_tok+512, mode_time(score)+532, mode_dur(score)+660, mode_pitch(score)+tv+788]
67
+
68
+ seq += score
69
+
70
+ seq += [917]
71
+
72
+ seq = seq + [PAD_IDX] * (SEQ_LEN - len(seq))
73
+
74
+ train_data.append(seq)
75
+
76
+ #====================================================================================
77
+
78
+ print('Done!')
79
+ print('=' * 70)
80
+ print('All data is good:', len(max(train_data, key=len)) == len(min(train_data, key=len)))
81
+ print('=' * 70)
82
+ print('Randomizing training data...')
83
+ random.shuffle(train_data)
84
+ print('Done!')
85
+ print('=' * 70)
86
+ print('Total length of training data:', len(train_data))
87
+ print('=' * 70)
88
+
89
+ #====================================================================================
90
+
91
+ print('Loading Popular Hook Transformer pre-trained model...')
92
+ print('=' * 70)
93
+
94
+ print('Instantiating model...')
95
+
96
+ model = TransformerWrapper(
97
+ num_tokens = PAD_IDX+1,
98
+ max_seq_len = SEQ_LEN,
99
+ attn_layers = Decoder(dim = 1024,
100
+ depth = 4,
101
+ heads = 32,
102
+ rotary_pos_emb = True,
103
+ attn_flash = True
104
+ )
105
+ )
106
+
107
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
108
+
109
+ print('=' * 70)
110
+ print('Loading model checkpoint...')
111
+
112
+ model_path = 'Popular_Hook_Transformer_Small_Trained_Model_10869_steps_0.2308_loss_0.9252_acc.pth'
113
+
114
+ model.load_state_dict(torch.load(model_path))
115
 
116
+ print('=' * 70)
 
117
 
118
+ model.to(DEVICE)
119
+ model.eval()
120
 
121
+ ctx = torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)
 
122
 
123
+ print('Done!')
124
+ print('=' * 70)
125
 
126
  #====================================================================================
127
 
128
+ def Generate_POP_Section(input_parsons_code,
129
+ input_first_note_duration,
130
+ iinput_first_note_MIDI_pitch
131
+ ):
132
 
133
  print('=' * 70)
134
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))