fffiloni committed on
Commit
1c42a58
·
1 Parent(s): 05e653a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -11,9 +11,18 @@ from share_btn import community_icon_html, loading_icon_html, share_js
11
  MODEL_ID = "riffusion/riffusion-model-v1"
12
  pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
13
  pipe = pipe.to("cuda")
 
 
14
 
 
15
 
16
- def predict(prompt, negative_prompt, duration):
 
 
 
 
 
 
17
  if duration == 5:
18
  width_duration=512
19
  else :
@@ -25,6 +34,13 @@ def predict(prompt, negative_prompt, duration):
25
  f.write(wav[0].getbuffer())
26
  return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
27
 
 
 
 
 
 
 
 
28
 
29
  title = """
30
  <div style="text-align: center; max-width: 500px; margin: 0 auto;">
@@ -142,6 +158,7 @@ with gr.Blocks(css=css) as demo:
142
  gr.HTML(title)
143
 
144
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
 
145
  with gr.Row():
146
  negative_prompt = gr.Textbox(label="Negative prompt")
147
  duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=8, elem_id="duration-slider")
@@ -160,7 +177,7 @@ with gr.Blocks(css=css) as demo:
160
 
161
  gr.HTML(article)
162
 
163
- send_btn.click(predict, inputs=[prompt_input, negative_prompt, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
164
  share_button.click(None, [], [], _js=share_js)
165
 
166
  demo.queue(max_size=250).launch(debug=True)
 
11
  MODEL_ID = "riffusion/riffusion-model-v1"
12
  pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
13
  pipe = pipe.to("cuda")
14
+ pipe2 = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
15
+ pipe2 = pipe2.to(device)
16
 
17
+ spectro_from_wav = gr.Interface.load("spaces/fffiloni/audio-to-spectrogram")
18
 
19
def predict(prompt, negative_prompt, audio_input, duration):
    """Dispatch a generation request to the matching pipeline.

    Args:
        prompt: Text from the "Musical prompt" textbox.
        negative_prompt: Text from the "Negative prompt" textbox.
        audio_input: Value of the audio upload component (configured with
            ``type="filepath"``), or ``None`` when no audio was supplied.
        duration: Value of the "Duration in seconds" slider; only used when
            no audio is supplied.

    Returns:
        Whatever ``classic`` or ``style_transfer`` returns; both produce the
        five values wired to the click handler's outputs (spectrogram, audio
        file, and three visibility updates).
    """
    # Identity comparison is the correct test for None: `== None` can be
    # hijacked by objects that overload __eq__.
    if audio_input is None:
        return classic(prompt, negative_prompt, duration)
    return style_transfer(prompt, negative_prompt, audio_input)
24
+
25
+ def classic(prompt, negative_prompt, duration):
26
  if duration == 5:
27
  width_duration=512
28
  else :
 
34
  f.write(wav[0].getbuffer())
35
  return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
36
 
37
def style_transfer(prompt, negative_prompt, audio_input):
    """Re-style an uploaded audio clip according to a text prompt.

    The clip is turned into a spectrogram, the spectrogram is run through the
    image-to-image diffusion pipeline, and the resulting spectrogram is
    converted back to audio and written to ``output.wav``.

    Args:
        prompt: Text describing the target style.
        negative_prompt: Text describing what the result should avoid.
        audio_input: The uploaded audio, as passed on by ``predict``.

    Returns:
        A 5-tuple: the generated spectrogram image, the path ``'output.wav'``,
        and three ``gr.update(visible=True)`` objects for the share widgets.
    """
    spec = spectro_from_wav(audio_input)
    # NOTE(review): `spectro_from_wav` is a remotely loaded Space; it may hand
    # back a file path rather than an image object. If so, the spectrogram has
    # to be opened as an image before it is passed to the pipeline - confirm.

    # `image=` and `strength=` are image-to-image arguments, so this must go
    # through `pipe2` (the Img2Img pipeline), not the text-to-image `pipe`.
    new_spectro = pipe2(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=spec,
        strength=0.5,
        guidance_scale=7,
    ).images[0]

    # Render audio from the *generated* spectrogram; converting `spec` would
    # simply play back the uploaded clip unchanged.
    wav = wav_bytes_from_spectrogram_image(new_spectro)
    with open("output.wav", "wb") as f:
        f.write(wav[0].getbuffer())

    return (
        new_spectro,
        'output.wav',
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
44
 
45
  title = """
46
  <div style="text-align: center; max-width: 500px; margin: 0 auto;">
 
158
  gr.HTML(title)
159
 
160
  prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
161
+ audio_input = gr.Audio(source="upload", type="filepath")
162
  with gr.Row():
163
  negative_prompt = gr.Textbox(label="Negative prompt")
164
  duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=8, elem_id="duration-slider")
 
177
 
178
  gr.HTML(article)
179
 
180
+ send_btn.click(predict, inputs=[prompt_input, negative_prompt, audio_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
181
  share_button.click(None, [], [], _js=share_js)
182
 
183
  demo.queue(max_size=250).launch(debug=True)