asigalov61 commited on
Commit
b9122dd
·
verified ·
1 Parent(s): 8bfdb32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -99
app.py CHANGED
@@ -174,10 +174,10 @@ def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
174
  @spaces.GPU
175
  def Generate_Accompaniment(input_midi,
176
  input_melody,
177
- generation_type,
178
  melody_patch,
179
  use_nth_note,
180
- model_temperature
 
181
  ):
182
 
183
  #===============================================================================
@@ -220,48 +220,11 @@ def Generate_Accompaniment(input_midi,
220
 
221
  #===============================================================================
222
 
223
- def generate_block_seq(input_seq, trg_dtime, temperature=0.9):
224
-
225
- inp_seq = copy.deepcopy(input_seq)
226
-
227
- block_seq = []
228
-
229
- cur_time = 0
230
-
231
- while cur_time < trg_dtime:
232
-
233
- x = torch.LongTensor(inp_seq).cuda()
234
-
235
- with ctx:
236
- out = model.generate(x,
237
- 1,
238
- temperature=temperature,
239
- return_prime=False,
240
- verbose=False)
241
-
242
- y = out.tolist()[0][0]
243
-
244
- if y < 128:
245
- cur_time += y
246
-
247
- inp_seq.append(y)
248
- block_seq.append(y)
249
-
250
- if cur_time != trg_dtime:
251
- return []
252
-
253
- else:
254
- return block_seq
255
-
256
- #===============================================================================
257
-
258
  print('=' * 70)
259
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
260
  start_time = reqtime.time()
261
  print('=' * 70)
262
 
263
-
264
-
265
  print('=' * 70)
266
  print('Requested settings:')
267
  print('=' * 70)
@@ -270,10 +233,10 @@ def Generate_Accompaniment(input_midi,
270
  fn1 = fn.split('.')[0]
271
  print('Input MIDI file name:', fn)
272
  print('Input sample melody:', input_melody)
273
- print('Generation type:', generation_type)
274
  print('Source melody patch:', melody_patch)
275
  print('Use nth melody note:', use_nth_note)
276
  print('Model temperature:', model_temperature)
 
277
 
278
  print('=' * 70)
279
 
@@ -315,58 +278,7 @@ def Generate_Accompaniment(input_midi,
315
 
316
  #==================================================================
317
 
318
- if generation_type == 'Guided':
319
-
320
- input_seq = []
321
-
322
- input_seq.extend(start_score_seq)
323
- input_seq.extend(score_list[0][0])
324
-
325
- block_seq_lens = []
326
-
327
- idx = 0
328
-
329
- max_retries = 2
330
- mrt = 0
331
-
332
- while idx < len(score_list)-1:
333
-
334
- if idx % 10 == 0:
335
- print('Generating', idx, 'block')
336
-
337
- input_seq.extend(score_list[idx][1])
338
-
339
- block_seq = []
340
-
341
- for _ in range(max_retries):
342
-
343
- block_seq = generate_block_seq(input_seq, score_list[idx+1][0][0])
344
-
345
- if block_seq:
346
- break
347
-
348
- if block_seq:
349
- input_seq.extend(block_seq)
350
- block_seq_lens.append(len(block_seq))
351
- idx += 1
352
- mrt = 0
353
-
354
- else:
355
-
356
- if block_seq_lens:
357
- input_seq = input_seq[:-(block_seq_lens[-1]+2)]
358
- block_seq_lens.pop()
359
- idx -= 1
360
- mrt += 1
361
-
362
- else:
363
- break
364
-
365
- if mrt == max_retries:
366
- break
367
-
368
- else:
369
- input_seq = generate_full_seq(start_score_seq, temperature=model_temperature)
370
 
371
  final_song = input_seq[len(start_score_seq):]
372
 
@@ -492,9 +404,9 @@ with gr.Blocks() as demo:
492
 
493
  gr.Markdown("## Generation options")
494
 
495
- generation_type = gr.Radio(["Guided", "Freestyle"], value="Guided", label="Generation type")
496
  melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
497
  use_nth_note = gr.Slider(1, 8, value=1, step=1, label="Use each nth melody note")
 
498
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
499
 
500
  generate_btn = gr.Button("Generate", variant="primary")
@@ -508,10 +420,10 @@ with gr.Blocks() as demo:
508
  generate_btn.click(Generate_Accompaniment,
509
  [input_midi,
510
  input_melody,
511
- generation_type,
512
  melody_patch,
513
  use_nth_note,
514
- model_temperature
 
515
  ],
516
  [output_audio,
517
  output_plot,
@@ -520,15 +432,15 @@ with gr.Blocks() as demo:
520
  )
521
 
522
  gr.Examples(
523
- [["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", "Freestyle", -1, 1, 0.9],
524
- ["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", "Guided", -1, 1, 0.9]
525
  ],
526
  [input_midi,
527
  input_melody,
528
- generation_type,
529
  melody_patch,
530
  use_nth_note,
531
- model_temperature
 
532
  ],
533
  [output_audio,
534
  output_plot,
 
174
  @spaces.GPU
175
  def Generate_Accompaniment(input_midi,
176
  input_melody,
 
177
  melody_patch,
178
  use_nth_note,
179
+ model_temperature,
180
+ model_sampling_top_k
181
  ):
182
 
183
  #===============================================================================
 
220
 
221
  #===============================================================================
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  print('=' * 70)
224
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
225
  start_time = reqtime.time()
226
  print('=' * 70)
227
 
 
 
228
  print('=' * 70)
229
  print('Requested settings:')
230
  print('=' * 70)
 
233
  fn1 = fn.split('.')[0]
234
  print('Input MIDI file name:', fn)
235
  print('Input sample melody:', input_melody)
 
236
  print('Source melody patch:', melody_patch)
237
  print('Use nth melody note:', use_nth_note)
238
  print('Model temperature:', model_temperature)
239
+ print('Model top k':, model_sampling_top_k)
240
 
241
  print('=' * 70)
242
 
 
278
 
279
  #==================================================================
280
 
281
+ input_seq = generate_full_seq(start_score_seq, temperature=model_temperature)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  final_song = input_seq[len(start_score_seq):]
284
 
 
404
 
405
  gr.Markdown("## Generation options")
406
 
 
407
  melody_patch = gr.Slider(-1, 127, value=-1, step=1, label="Source melody MIDI patch")
408
  use_nth_note = gr.Slider(1, 8, value=1, step=1, label="Use each nth melody note")
409
+ model_sampling_top_k = gr.Slider(1, 100, value=15, step=1, label="Model sampling top k value")
410
  model_temperature = gr.Slider(0.1, 1, value=0.9, step=0.01, label="Model temperature")
411
 
412
  generate_btn = gr.Button("Generate", variant="primary")
 
420
  generate_btn.click(Generate_Accompaniment,
421
  [input_midi,
422
  input_melody,
 
423
  melody_patch,
424
  use_nth_note,
425
+ model_temperature,
426
+ model_sampling_top_k
427
  ],
428
  [output_audio,
429
  output_plot,
 
432
  )
433
 
434
  gr.Examples(
435
+ [["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", -1, 1, 0.9, 15],
436
+ ["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", "Guided", -1, 1, 0.9, 15]
437
  ],
438
  [input_midi,
439
  input_melody,
 
440
  melody_patch,
441
  use_nth_note,
442
+ model_temperature,
443
+ model_sampling_top_k
444
  ],
445
  [output_audio,
446
  output_plot,