Sourikta commited on
Commit
a8b76e1
·
verified ·
1 Parent(s): 1181c71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +417 -0
app.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import spaces
4
+ import json
5
+ import re
6
+ import random
7
+ import numpy as np
8
+ from gradio_client import Client, handle_file
9
+ hf_token = os.environ.get("HF_TOKEN")
10
+
11
+ MAX_SEED = np.iinfo(np.int32).max
12
+
13
def check_api(model_name):
    """Probe the Hugging Face Space backing `model_name` and report reachability.

    Parameters
    ----------
    model_name : str
        One of the model labels offered in the UI dropdown.

    Returns
    -------
    str or None
        "api ready" if a gradio client could be built, "api not ready yet" on
        any failure, or None for an unknown model name (preserving the original
        fall-through behaviour).
    """
    # Space id per model, plus whether the Space is private (needs the HF token).
    spaces_by_model = {
        "MAGNet": ("fffiloni/MAGNet", False),
        "AudioLDM-2": ("fffiloni/audioldm2-text2audio-text2music-API", True),
        "Riffusion": ("fffiloni/spectrogram-to-music", False),
        "Mustango": ("fffiloni/mustango-API", True),
        "MusicGen": ("https://facebook-musicgen.hf.space/", False),
        "Stable Audio Open": ("fffiloni/Stable-Audio-A10", True) if False else ("fffiloni/Stable-Audio-Open-A10", True),
    }
    entry = spaces_by_model.get(model_name)
    if entry is None:
        return None
    space_id, needs_token = entry
    try:
        # Building the client fails when the Space is asleep or unreachable.
        # Passing hf_token=None is equivalent to omitting it for public Spaces.
        Client(space_id, hf_token=hf_token if needs_token else None)
        return "api ready"
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        return "api not ready yet"
50
+
51
+
52
+ from moviepy.editor import VideoFileClip
53
+ from moviepy.audio.AudioClip import AudioClip
54
+
55
def extract_audio(video_in, output_audio='audio.wav'):
    """Extract the audio track of a video file to a WAV file.

    Parameters
    ----------
    video_in : str
        Path to the input video file.
    output_audio : str
        Destination .wav path. Defaults to 'audio.wav', the name the original
        hard-coded, so existing callers are unaffected.

    Returns
    -------
    str
        Path of the written .wav file.
    """
    video_clip = VideoFileClip(video_in)
    try:
        audio_clip = video_clip.audio
        # 44100 Hz is the conventional sample rate for .wav output.
        audio_clip.write_audiofile(output_audio, fps=44100)
        print("Audio extraction complete.")
    finally:
        # Close the clip to release the underlying reader processes/file handles
        # (the original leaked them on every call).
        video_clip.close()
    return output_audio
68
+
69
+
70
+
71
def get_caption(image_in):
    """Caption an image with the Kosmos-2 Space and return the caption text.

    Parameters
    ----------
    image_in : str
        Filepath of the image to caption.

    Returns
    -------
    str
        The caption assembled from the Kosmos-2 token stream.
    """
    kosmos2_client = Client("fffiloni/Kosmos-2-API", hf_token=hf_token)
    kosmos2_result = kosmos2_client.predict(
        image_input=handle_file(image_in),
        text_input="Detailed",
        api_name="/generate_predictions"
    )
    print(f"KOSMOS2 RETURNS: {kosmos2_result}")

    # The second element of the result holds the token list; the first token
    # is skipped and the rest are concatenated back into a sentence.
    data = kosmos2_result[1]
    sentence = ''.join(item['token'] for item in data[1:])

    return sentence
95
+
96
def get_caption_from_MD(image_in):
    """Describe an image via the moondream1 Space and return its answer.

    Alternative captioner to `get_caption` (currently unused by `infer`).
    """
    md_client = Client("https://vikhyatk-moondream1.hf.space/")
    answer = md_client.predict(
        image_in,                         # filepath for the 'image' Image component
        "Describe precisely the image.",  # question for the 'Question' Textbox component
        api_name="/answer_question"
    )
    print(answer)
    return answer
105
+
106
def get_magnet(prompt):
    """Generate music from `prompt` with the MAGNet Space; return the audio path."""
    magnet_client = Client("fffiloni/MAGNet")
    prediction = magnet_client.predict(
        model="facebook/magnet-small-10secs",   # smallest/fastest 10s checkpoint
        model_path="",                          # no custom model
        text=prompt,
        temperature=3,
        topp=0.9,
        max_cfg_coef=10,
        min_cfg_coef=1,
        decoding_steps1=20,
        decoding_steps2=10,
        decoding_steps3=10,
        decoding_steps4=10,
        span_score="prod-stride1 (new!)",
        api_name="/predict_full"
    )
    print(prediction)
    # The Space returns a tuple; index 1 carries the generated audio.
    return prediction[1]
126
+
127
def get_audioldm(prompt):
    """Generate a 10-second audio clip from `prompt` with AudioLDM-2."""
    audioldm_client = Client("fffiloni/audioldm2-text2audio-text2music-API", hf_token=hf_token)
    random_seed = random.randint(0, MAX_SEED)
    prediction = audioldm_client.predict(
        text=prompt,
        negative_prompt="Low quality.",
        duration=10,            # seconds, within the Space's 5-15 range
        guidance_scale=6.5,
        random_seed=random_seed,
        n_candidates=3,         # waveforms generated per call (1-5)
        api_name="/text2audio"
    )
    print(prediction)
    return prediction
142
+
143
def get_riffusion(prompt):
    """Generate music from `prompt` with the Riffusion spectrogram Space."""
    riffusion_client = Client("fffiloni/spectrogram-to-music")
    prediction = riffusion_client.predict(
        prompt=prompt,
        negative_prompt="",
        audio_input=None,   # no audio conditioning
        duration=10,        # seconds (5-10 supported by the Space)
        api_name="/predict"
    )
    print(prediction)
    # Index 1 of the returned tuple is the generated audio.
    return prediction[1]
154
+
155
def get_mustango(prompt):
    """Generate music from `prompt` with the Mustango Space."""
    mustango_client = Client("fffiloni/mustango-API", hf_token=hf_token)
    prediction = mustango_client.predict(
        prompt=prompt,
        steps=200,      # diffusion steps (100-200)
        guidance=6,     # guidance scale (1-10)
        api_name="/predict"
    )
    print(prediction)
    return prediction
165
+
166
def get_musicgen(prompt):
    """Generate music from `prompt` with the facebook/musicgen Space."""
    musicgen_client = Client("https://facebook-musicgen.hf.space/")
    prediction = musicgen_client.predict(
        prompt,     # description for the 'Describe your music' Textbox
        None,       # no melody-conditioning audio file
        fn_index=0  # legacy positional endpoint on this Space
    )
    print(prediction)
    # Index 1 of the returned tuple is the generated audio.
    return prediction[1]
175
+
176
def get_stable_audio_open(prompt):
    """Generate a 10-second clip from `prompt` with Stable Audio Open."""
    sao_client = Client("fffiloni/Stable-Audio-Open-A10", hf_token=hf_token)
    prediction = sao_client.predict(
        prompt=prompt,
        seconds_total=10,
        steps=100,
        cfg_scale=7,
        api_name="/predict"
    )
    print(prediction)
    return prediction
187
+
188
import re       # NOTE(review): duplicate of the top-of-file import; harmless
import torch
from transformers import pipeline

# Candidate instruction-following models for prompt rewriting.
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # NOTE(review): unused — kept for reference only

# Loaded once at import time; bfloat16 + device_map="auto" lets transformers
# place the model on the available accelerator(s).
pipe = pipeline("text-generation", model=zephyr_model, torch_dtype=torch.bfloat16, device_map="auto")
196
+
197
# System prompt fed to the LLM: rewrite an image caption into one short
# ready-to-use musical prompt. NOTE(review): the f-prefix is unnecessary
# (no placeholders) but harmless; "succintly" typo kept — this is a runtime
# string and changing it would alter the LLM input.
standard_sys = f"""
You are a musician AI whose job is to help users create their own music which its genre will reflect the character or scene from an image described by users.
In particular, you need to respond succintly with few musical words, in a friendly tone, write a musical prompt for a music generation model.
For example, if a user says, "a picture of a man in a black suit and tie riding a black dragon", provide immediately a musical prompt corresponding to the image description.
Immediately STOP after that. It should be EXACTLY in this format:
"A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle"
"""

# Mustango-specific variant that additionally demands a chord progression.
# NOTE(review): currently unused — get_musical_prompt always picks standard_sys.
mustango_sys = f"""
You are a musician AI whose job is to help users create their own music which its genre will reflect the character or scene from an image described by users.
In particular, you need to respond succintly with few musical words, in a friendly tone, write a musical prompt for a music generation model, you MUST include chords progression.
For example, if a user says, "a painting of three old women having tea party", provide immediately a musical prompt corresponding to the image description.
Immediately STOP after that. It should be EXACTLY in this format:
"The song is an instrumental. The song is in medium tempo with a classical guitar playing a lilting melody in accompaniment style. The song is emotional and romantic. The song is a romantic instrumental song. The chord sequence is Gm, F6, Ebm. The time signature is 4/4. This song is in Adagio. The key of this song is G minor."
"""
212
+
213
@spaces.GPU(enable_queue=True)
def get_musical_prompt(user_prompt, chosen_model):
    """Turn an image caption into a short musical prompt via the zephyr pipeline.

    Parameters
    ----------
    user_prompt : str
        The image caption to rewrite.
    chosen_model : str
        Currently unused: every model gets `standard_sys` (a Mustango-specific
        system prompt exists but its selection was disabled upstream).

    Returns
    -------
    str
        The generated musical prompt with leading newlines stripped.
    """
    agent_maker_sys = standard_sys

    instruction = f"""
<|system|>
{agent_maker_sys}</s>
<|user|>
"""

    prompt = f"{instruction.strip()}\n{user_prompt}</s>"
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    # Drop everything between the system tag and the assistant tag so only the
    # freshly generated text remains.
    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
    cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)

    print(f"SUGGESTED Musical prompt: {cleaned_text}")
    return cleaned_text.lstrip("\n")
237
+
238
def infer(image_in, chosen_model, api_status):
    """Full pipeline: caption the image, build a musical prompt, generate music.

    Parameters
    ----------
    image_in : str or None
        Filepath of the reference image.
    chosen_model : str or None
        Dropdown value naming the music model.
    api_status : str
        Latest `check_api` result for the chosen model.

    Returns
    -------
    tuple
        (caption textbox update, retry-button visibility update, audio output).

    Raises
    ------
    gr.Error
        On missing image, missing/unknown model, or an unreachable API.
    """
    if image_in is None:  # was `== None`; identity test is the correct idiom
        raise gr.Error("Please provide an image input")

    # Was `chosen_model == []`, which missed None/"" — the dropdown's actual
    # empty values. Truthiness covers all of them.
    if not chosen_model:
        raise gr.Error("Please pick a model")

    if api_status == "api not ready yet":
        raise gr.Error("This model is not ready yet, you can pick another one instead :)")

    gr.Info("Getting image caption with Kosmos-2...")
    user_prompt = get_caption(image_in)
    #user_prompt = get_caption_from_MD(image_in)

    gr.Info("Building a musical prompt according to the image caption ...")
    musical_prompt = get_musical_prompt(user_prompt, chosen_model)

    # Dispatch table replaces the six-way if/elif chain.
    generators = {
        "MAGNet": get_magnet,
        "AudioLDM-2": get_audioldm,
        "Riffusion": get_riffusion,
        "Mustango": get_mustango,
        "MusicGen": get_musicgen,
        "Stable Audio Open": get_stable_audio_open,
    }
    generate = generators.get(chosen_model)
    if generate is None:
        # Original fell through with `music_o` unbound, crashing with NameError;
        # surface an explicit UI error instead.
        raise gr.Error(f"Unknown model: {chosen_model}")
    gr.Info(f"Now calling {chosen_model} for music...")
    music_o = generate(musical_prompt)

    return gr.update(value=musical_prompt, interactive=True), gr.update(visible=True), music_o
275
+
276
def retry(chosen_model, caption):
    """Regenerate music from a (possibly user-edited) prompt, skipping captioning.

    Parameters
    ----------
    chosen_model : str
        Dropdown value naming the music model.
    caption : str
        The musical prompt to reuse (the editable textbox content).

    Returns
    -------
    The chosen generator's audio output, or None for an unknown model
    (preserving the original fall-through behaviour).
    """
    musical_prompt = caption

    # Dispatch table replaces the six-way if/elif chain duplicated from infer().
    generators = {
        "MAGNet": get_magnet,
        "AudioLDM-2": get_audioldm,
        "Riffusion": get_riffusion,
        "Mustango": get_mustango,
        "MusicGen": get_musicgen,
        "Stable Audio Open": get_stable_audio_open,
    }
    generate = generators.get(chosen_model)
    if generate is None:
        return None
    gr.Info(f"Now calling {chosen_model} for music...")
    return generate(musical_prompt)
300
+
301
demo_title = "Image to Music V2"
description = "Get music from a picture, compare text-to-music models"

# Page-level styling: centered column and a larger font for the prompt box.
css = """
#col-container {
    margin: 0 auto;
    max-width: 980px;
    text-align: left;
}
#inspi-prompt textarea {
    font-size: 20px;
    line-height: 24px;
    font-weight: 600;
}
"""

with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):

        gr.HTML(f"""
        <h2 style="text-align: center;">{demo_title}</h2>
        <p style="text-align: center;">{description}</p>
        """)

        with gr.Row():

            # Left column: inputs (image, model choice, API status, submit).
            with gr.Column():
                image_in = gr.Image(
                    label = "Image reference",
                    type = "filepath",
                    elem_id = "image-in"
                )

                with gr.Row():

                    chosen_model = gr.Dropdown(
                        label = "Choose a model",
                        choices = [
                            #"MAGNet",
                            "AudioLDM-2",
                            "Riffusion",
                            "Mustango",
                            #"MusicGen",
                            "Stable Audio Open"
                        ],
                        value = None,
                        filterable = False
                    )

                    # Filled by check_api whenever the dropdown changes.
                    check_status = gr.Textbox(
                        label="API status",
                        interactive=False
                    )

                submit_btn = gr.Button("Make music from my pic !")

                gr.Examples(
                    examples = [
                        ["examples/ocean_poet.jpeg"],
                        ["examples/jasper_horace.jpeg"],
                        ["examples/summer.jpeg"],
                        ["examples/mona_diner.png"],
                        ["examples/monalisa.png"],
                        ["examples/santa.png"],
                        ["examples/winter_hiking.png"],
                        ["examples/teatime.jpeg"],
                        ["examples/news_experts.jpeg"]
                    ],
                    fn = infer,
                    inputs = [image_in, chosen_model],
                    examples_per_page = 4
                )

            # Right column: outputs (editable prompt, retry button, audio).
            with gr.Column():

                caption = gr.Textbox(
                    label = "Inspirational musical prompt",
                    interactive = False,
                    elem_id = "inspi-prompt"
                )

                # Hidden until a first generation succeeds (see infer's outputs).
                retry_btn = gr.Button("Retry with edited prompt", visible=False)

                result = gr.Audio(
                    label = "Music"
                )

    # Probe the selected model's API immediately, bypassing the queue.
    chosen_model.change(
        fn = check_api,
        inputs = chosen_model,
        outputs = check_status,
        queue = False
    )

    # Regenerate music from the (possibly edited) prompt without re-captioning.
    retry_btn.click(
        fn = retry,
        inputs = [chosen_model, caption],
        outputs = [result]
    )

    # Full pipeline: caption -> musical prompt -> music generation.
    submit_btn.click(
        fn = infer,
        inputs = [
            image_in,
            chosen_model,
            check_status
        ],
        outputs =[
            caption,
            retry_btn,
            result
        ]
    )

demo.queue(max_size=16).launch(show_api=False, show_error=True)