asigalov61 commited on
Commit
cee38fa
·
verified ·
1 Parent(s): 8c0a088

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -1
app.py CHANGED
@@ -9,6 +9,7 @@ print('=' * 70)
9
  print('Loading core Guided Accompaniment Transformer modules...')
10
 
11
  import os
 
12
 
13
  import time as reqtime
14
  import datetime
@@ -155,6 +156,81 @@ def Generate_Accompaniment(input_midi,
155
  ):
156
 
157
  #===============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  print('=' * 70)
160
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
@@ -178,9 +254,10 @@ def Generate_Accompaniment(input_midi,
178
 
179
  score, score_list = load_midi(input_midi.name)
180
 
 
 
181
  #==================================================================
182
 
183
- print('Sample score events', score[:12])
184
  print('=' * 70)
185
  print('Generating...')
186
 
 
9
  print('Loading core Guided Accompaniment Transformer modules...')
10
 
11
  import os
12
+ import copy
13
 
14
  import time as reqtime
15
  import datetime
 
156
  ):
157
 
158
  #===============================================================================
159
+
160
+ def generate_full_seq(input_seq, temperature=0.9, verbose=True):
161
+
162
+ seq_abs_run_time = sum([t for t in input_seq if t < 128])
163
+
164
+ cur_time = 0
165
+
166
+ full_seq = input_seq
167
+
168
+ toks_counter = 0
169
+
170
+ while cur_time < seq_abs_run_time:
171
+
172
+ if verbose:
173
+ if toks_counter % 128 == 0:
174
+ print('Generated', toks_counter, 'tokens')
175
+
176
+ x = torch.LongTensor(full_seq).cuda()
177
+
178
+ with ctx:
179
+ out = model.generate(x,
180
+ 1,
181
+ temperature=temperature,
182
+ return_prime=False,
183
+ verbose=False)
184
+
185
+ y = out.tolist()[0][0]
186
+
187
+ if y < 128:
188
+ cur_time += y
189
+
190
+ full_seq.append(y)
191
+
192
+ toks_counter += 1
193
+
194
+ return full_seq
195
+
196
+ #===============================================================================
197
+
198
+ def generate_block_seq(input_seq, trg_dtime, temperature=0.9):
199
+
200
+ cur_time = 0
201
+
202
+ block_seq = [128]
203
+
204
+ while cur_time != trg_dtime and len(block_seq) < 2 and block_seq[-1] > 127:
205
+
206
+ inp_seq = copy.deepcopy(input_seq)
207
+
208
+ block_seq = []
209
+
210
+ cur_time = 0
211
+
212
+ while cur_time < trg_dtime:
213
+
214
+ x = torch.LongTensor(inp_seq).cuda()
215
+
216
+ with ctx:
217
+ out = model.generate(x,
218
+ 1,
219
+ temperature=temperature,
220
+ return_prime=False,
221
+ verbose=False)
222
+
223
+ y = out.tolist()[0][0]
224
+
225
+ if y < 128:
226
+ cur_time += y
227
+
228
+ inp_seq.append(y)
229
+ block_seq.append(y)
230
+
231
+ return block_seq
232
+
233
+ #===============================================================================
234
 
235
  print('=' * 70)
236
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
 
254
 
255
  score, score_list = load_midi(input_midi.name)
256
 
257
+ print('Sample score events', score[:12])
258
+
259
  #==================================================================
260
 
 
261
  print('=' * 70)
262
  print('Generating...')
263