asigalov61 commited on
Commit
26f179f
·
verified ·
1 Parent(s): b731f7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -65,6 +65,8 @@ SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
65
 
66
  MAX_MELODY_NOTES = 64
67
 
 
 
68
  #==================================================================================
69
 
70
  print('=' * 70)
@@ -184,7 +186,7 @@ def Generate_Accompaniment(input_midi,
184
 
185
  #===============================================================================
186
 
187
- def generate_full_seq(input_seq, max_toks=3072, temperature=0.9, verbose=True):
188
 
189
  seq_abs_run_time = sum([t for t in input_seq if t < 128])
190
 
@@ -205,6 +207,8 @@ def Generate_Accompaniment(input_midi,
205
  with ctx:
206
  out = model.generate(x,
207
  1,
 
 
208
  temperature=temperature,
209
  return_prime=False,
210
  verbose=False)
@@ -217,6 +221,9 @@ def Generate_Accompaniment(input_midi,
217
  full_seq.append(y)
218
 
219
  toks_counter += 1
 
 
 
220
 
221
  return full_seq
222
 
@@ -281,7 +288,11 @@ def Generate_Accompaniment(input_midi,
281
 
282
  #==================================================================
283
 
284
- input_seq = generate_full_seq(start_score_seq, temperature=model_temperature)
 
 
 
 
285
 
286
  final_song = input_seq[len(start_score_seq):]
287
 
 
65
 
66
  MAX_MELODY_NOTES = 64
67
 
68
+ MAX_GEN_TOKS = 3072
69
+
70
  #==================================================================================
71
 
72
  print('=' * 70)
 
186
 
187
  #===============================================================================
188
 
189
+ def generate_full_seq(input_seq, max_toks=3072, temperature=0.9, top_k_value=15, verbose=True):
190
 
191
  seq_abs_run_time = sum([t for t in input_seq if t < 128])
192
 
 
207
  with ctx:
208
  out = model.generate(x,
209
  1,
210
+ filter_logits_fn=top_k,
211
+ filter_kwargs={'thres': top_k_value},
212
  temperature=temperature,
213
  return_prime=False,
214
  verbose=False)
 
221
  full_seq.append(y)
222
 
223
  toks_counter += 1
224
+
225
+ if toks_counter == max_toks:
226
+ return full_seq
227
 
228
  return full_seq
229
 
 
288
 
289
  #==================================================================
290
 
291
+ input_seq = generate_full_seq(start_score_seq,
292
+ max_toks=MAX_GEN_TOKS,
293
+ temperature=model_temperature,
294
+ top_k_value=model_sampling_top_k,
295
+ )
296
 
297
  final_song = input_seq[len(start_score_seq):]
298