reach-vb HF staff committed on
Commit
ba2b8a5
·
1 Parent(s): ee5a6df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -5,16 +5,17 @@ import torch
5
 
6
  from diffusers import SpectrogramDiffusionPipeline, MidiProcessor
7
 
8
- pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
9
- pipe = pipe.to("cuda")
 
10
  processor = MidiProcessor()
11
 
12
 
13
  def predict(audio_file_pth):
14
- # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
15
 
16
- output = pipe(processor(audio_file_pth.name)[:2])
17
- audio = output.audios[0]
 
18
 
19
  return (16000, audio.ravel())
20
 
@@ -33,7 +34,7 @@ examples = []
33
  gr.Interface(
34
  fn=predict,
35
  inputs=[
36
- gr.File(file_count="single", file_types=[".mid"]),
37
  ],
38
  outputs=[
39
  gr.Audio(label="Synthesised Music", type="numpy"),
 
5
 
6
  from diffusers import SpectrogramDiffusionPipeline, MidiProcessor
7
 
8
+ pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion", torch_dtype=torch.float16).to("cuda")
9
+ pipe.enable_xformers_memory_efficient_attention()
10
+
11
  processor = MidiProcessor()
12
 
13
 
14
  def predict(audio_file_pth):
 
15
 
16
+ with torch.inference_mode():
17
+ output = pipe(processor(audio_file_pth.name)[:2])
18
+ audio = output.audios[0]
19
 
20
  return (16000, audio.ravel())
21
 
 
34
  gr.Interface(
35
  fn=predict,
36
  inputs=[
37
+ gr.File(label="Upload MIDI", file_count="single", file_types=[".mid"]),
38
  ],
39
  outputs=[
40
  gr.Audio(label="Synthesised Music", type="numpy"),