asigalov61 commited on
Commit
a173e60
·
verified ·
1 Parent(s): c4bfe55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
- def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_value):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = time.time()
@@ -32,7 +32,7 @@ def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_value):
32
  print('Req num tok:', num_tok)
33
  print('Req instr:', iinstr)
34
  print('Drums:', idrums)
35
- print('top_k:', input_top_k_value)
36
  print('-' * 70)
37
 
38
  if idrums:
@@ -127,8 +127,8 @@ def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_value):
127
  with torch.inference_mode():
128
  out = model.module.generate(inp,
129
  1,
130
- filter_logits_fn=top_k,
131
- filter_kwargs={'k': input_top_k_value},
132
  temperature=0.9,
133
  return_prime=False,
134
  verbose=False)
@@ -210,13 +210,13 @@ if __name__ == "__main__":
210
  value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
211
  input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
212
  input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
213
- input_top_k_value = gr.Slider(1, 100, value=15, label="Model sampling top_k value")
214
 
215
  run_btn = gr.Button("generate", variant="primary")
216
 
217
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
218
  output_plot = gr.Plot(label='output plot')
219
  output_midi = gr.File(label="output midi", file_types=[".mid"])
220
- run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument, input_top_k_value],
221
  [output_plot, output_midi, output_audio])
222
  app.queue().launch()
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
+ def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_ratio):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = time.time()
 
32
  print('Req num tok:', num_tok)
33
  print('Req instr:', iinstr)
34
  print('Drums:', idrums)
35
+ print('top_k:', input_top_k_ratio)
36
  print('-' * 70)
37
 
38
  if idrums:
 
127
  with torch.inference_mode():
128
  out = model.module.generate(inp,
129
  1,
130
+ filter_logits_fn = top_k,
131
+ filter_thres = input_top_k_ratio,
132
  temperature=0.9,
133
  return_prime=False,
134
  verbose=False)
 
210
  value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
211
  input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
212
  input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
213
+ input_top_k_ratio = gr.Slider(0.1, 1, value=0.95, step=0.01, label="Model sampling top_k ratio")
214
 
215
  run_btn = gr.Button("generate", variant="primary")
216
 
217
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
218
  output_plot = gr.Plot(label='output plot')
219
  output_midi = gr.File(label="output midi", file_types=[".mid"])
220
+ run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument, input_top_k_ratio],
221
  [output_plot, output_midi, output_audio])
222
  app.queue().launch()