Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,15 @@
|
|
|
|
1 |
from gradio_client import Client
|
2 |
import torch
|
3 |
import nltk # we'll use this to split into sentences
|
4 |
import numpy as np
|
5 |
from transformers import BarkModel, AutoProcessor
|
6 |
nltk.download('punkt')
|
|
|
7 |
import gradio as gr
|
|
|
|
|
|
|
8 |
|
9 |
def _grab_best_device(use_gpu=True):
|
10 |
if torch.cuda.device_count() > 0 and use_gpu:
|
@@ -15,11 +20,9 @@ def _grab_best_device(use_gpu=True):
|
|
15 |
|
16 |
device = _grab_best_device()
|
17 |
|
18 |
-
|
19 |
-
BATCH_SIZE = 8
|
20 |
SYST_PROMPT="""You're the storyteller, crafting a short tale for young listeners. Please abide by these guidelines:
|
21 |
-
- Keep your sentences concise and easy to understand.
|
22 |
-
- There should be only the narrator speaking.
|
23 |
|
24 |
#story_prompt = "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
|
25 |
story_prompt = "A princess breaks free from a dragon's grip. This evocates women empowerement and freedom."
|
@@ -27,10 +30,24 @@ temperature = 0.9
|
|
27 |
top_p = 0.6
|
28 |
repetition_penalty = 1.2
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
image_seed = 9
|
35 |
|
36 |
processor = AutoProcessor.from_pretrained("suno/bark")
|
@@ -41,29 +58,97 @@ voice_preset = "v2/en_speaker_6"
|
|
41 |
|
42 |
# convert to bettertransformer
|
43 |
model = model.to_bettertransformer()
|
|
|
44 |
|
45 |
# enable CPU offload
|
46 |
model.enable_cpu_offload()
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
|
50 |
def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
|
51 |
|
52 |
|
53 |
-
story =
|
54 |
-
story_prompt,
|
55 |
-
SYST_PROMPT,
|
56 |
-
temperature,
|
57 |
-
4096,
|
58 |
-
temperature,
|
59 |
-
repetition_penalty,
|
60 |
-
api_name="/chat"
|
61 |
-
)
|
62 |
|
|
|
63 |
|
64 |
model_input = story.replace("\n", " ").strip()
|
65 |
model_input = nltk.sent_tokenize(model_input)
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
pieces = []
|
68 |
for i in range(0, len(model_input), BATCH_SIZE):
|
69 |
inputs = model_input[BATCH_SIZE*i:min(BATCH_SIZE*(i+1), len(model_input))]
|
@@ -71,35 +156,26 @@ def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
|
|
71 |
if len(inputs) != 0:
|
72 |
inputs = processor(inputs, voice_preset=voice_preset)
|
73 |
|
74 |
-
speech_output = model.generate(**inputs.to(device)
|
75 |
|
|
|
|
|
|
|
76 |
pieces += [*speech_output, silence.copy()]
|
77 |
|
78 |
-
|
79 |
-
# story_prompt+image_positive_prompt, # str in 'parameter_11' Textbox component
|
80 |
-
# image_negative_prompt, # str in 'parameter_12' Textbox component
|
81 |
-
# "absolutereality_v181.safetensors [3d9d4d2b]", # str (Option from: ['absolutereality_V16.safetensors [37db0fc3]', 'absolutereality_v181.safetensors [3d9d4d2b]', 'analog-diffusion-1.0.ckpt [9ca13f02]', 'anythingv3_0-pruned.ckpt [2700c435]', 'anything-v4.5-pruned.ckpt [65745d25]', 'anythingV5_PrtRE.safetensors [893e49b9]', 'AOM3A3_orangemixs.safetensors [9600da17]', 'childrensStories_v13D.safetensors [9dfaabcb]', 'childrensStories_v1SemiReal.safetensors [a1c56dbb]', 'childrensStories_v1ToonAnime.safetensors [2ec7b88b]', 'cyberrealistic_v33.safetensors [82b0d085]', 'deliberate_v2.safetensors [10ec4b29]', 'deliberate_v3.safetensors [afd9d2d4]', 'dreamlike-anime-1.0.safetensors [4520e090]', 'dreamlike-diffusion-1.0.safetensors [5c9fd6e0]', 'dreamlike-photoreal-2.0.safetensors [fdcf65e7]', 'dreamshaper_6BakedVae.safetensors [114c8abb]', 'dreamshaper_7.safetensors [5cf5ae06]', 'dreamshaper_8.safetensors [9d40847d]', 'edgeOfRealism_eorV20.safetensors [3ed5de15]', 'EimisAnimeDiffusion_V1.ckpt [4f828a15]', 'elldreths-vivid-mix.safetensors [342d9d26]', 'epicrealism_naturalSinRC1VAE.safetensors [90a4c676]', 'ICantBelieveItsNotPhotography_seco.safetensors [4e7a3dfd]', 'juggernaut_aftermath.safetensors [5e20c455]', 'lyriel_v16.safetensors [68fceea2]', 'mechamix_v10.safetensors [ee685731]', 'meinamix_meinaV9.safetensors [2ec66ab0]', 'meinamix_meinaV11.safetensors [b56ce717]', 'openjourney_V4.ckpt [ca2f377f]', 'portraitplus_V1.0.safetensors [1400e684]', 'Realistic_Vision_V1.4-pruned-fp16.safetensors [8d21810b]', 'Realistic_Vision_V2.0.safetensors [79587710]', 'Realistic_Vision_V4.0.safetensors [29a7afaa]', 'Realistic_Vision_V5.0.safetensors [614d1063]', 'redshift_diffusion-V10.safetensors [1400e684]', 'revAnimated_v122.safetensors [3f4fefd9]', 'rundiffusionFX25D_v10.safetensors [cd12b0ee]', 'rundiffusionFX_v10.safetensors [cd4e694d]', 'sdv1_4.ckpt [7460a6fa]', 'v1-5-pruned-emaonly.safetensors [d7049739]', 'shoninsBeautiful_v10.safetensors [25d8c546]', 
'theallys-mix-ii-churned.safetensors [5d9225a4]', 'timeless-1.0.ckpt [7c4971d4]', 'toonyou_beta6.safetensors [980f6b15]'])
|
82 |
-
# 25,
|
83 |
-
# "Euler a",
|
84 |
-
# 7,
|
85 |
-
# 512,
|
86 |
-
# 512,
|
87 |
-
# image_seed,
|
88 |
-
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png,https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (path to directory with images and a file associating images with captions called captions.json)
|
89 |
-
# fn_index=0
|
90 |
-
#)
|
91 |
-
|
92 |
|
93 |
-
#
|
|
|
94 |
|
95 |
-
return story, (sampling_rate, np.concatenate(pieces))
|
96 |
|
97 |
|
98 |
|
99 |
|
100 |
# Gradio blocks demo
|
101 |
with gr.Blocks() as demo_blocks:
|
102 |
-
gr.Markdown("""<h1 align="center">🐶Children story
|
103 |
gr.HTML("""<h3 style="text-align:center;">📢Audio Streaming powered by Gradio (v3.40.0 onwards)🦾! </h3>""")
|
104 |
with gr.Group():
|
105 |
with gr.Row():
|
@@ -114,11 +190,16 @@ with gr.Blocks() as demo_blocks:
|
|
114 |
|
115 |
with gr.Row():
|
116 |
btn = gr.Button("Create a story")
|
117 |
-
|
|
|
|
|
|
|
118 |
with gr.Row():
|
119 |
out_audio = gr.Audio(
|
120 |
streaming=False, autoplay=True) # needed to stream output audio
|
121 |
out_text = gr.Text()
|
122 |
-
btn.click(generate_audio_and_image, [inp_text], [out_text, out_audio] ) #[out_audio]) #, out_count])
|
|
|
|
|
123 |
|
124 |
demo_blocks.queue().launch(debug=True)
|
|
|
1 |
+
from huggingface_hub import InferenceClient
|
2 |
from gradio_client import Client
|
3 |
import torch
|
4 |
import nltk # we'll use this to split into sentences
|
5 |
import numpy as np
|
6 |
from transformers import BarkModel, AutoProcessor
|
7 |
nltk.download('punkt')
|
8 |
+
|
9 |
import gradio as gr
|
10 |
+
import os
|
11 |
+
os.environ["GRADIO_TEMP_DIR"] = "/home/yoach/spaces/tmp"
|
12 |
+
|
13 |
|
14 |
def _grab_best_device(use_gpu=True):
|
15 |
if torch.cuda.device_count() > 0 and use_gpu:
|
|
|
20 |
|
21 |
device = _grab_best_device()
|
22 |
|
|
|
|
|
23 |
SYST_PROMPT="""You're the storyteller, crafting a short tale for young listeners. Please abide by these guidelines:
|
24 |
+
- Keep your sentences short, concise and easy to understand.
|
25 |
+
- There should be only the narrator speaking. If there are dialogues, they should be indirect."""
|
26 |
|
27 |
#story_prompt = "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
|
28 |
story_prompt = "A princess breaks free from a dragon's grip. This evocates women empowerement and freedom."
|
|
|
30 |
top_p = 0.6
|
31 |
repetition_penalty = 1.2
|
32 |
|
33 |
+
# Timeout, in seconds, handed to the hosted text-generation client.
# Override it through the TIMEOUT environment variable.
TIMEOUT = int(os.environ.get("TIMEOUT", 45))

# temperature, top_p and repetition_penalty are assigned once, just above this
# section; they are intentionally not re-assigned here.

# TODO: requirements: accelerate optimum

# Remote LLM endpoint used to write the story text.
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1",
    timeout=TIMEOUT,
)

# Remote Space used to draw the illustration.
# NOTE(review): this URL pins one specific replica ("--replicas/ffe2bn2dk/");
# presumably it stops resolving once the Space restarts -- confirm, and consider
# addressing the Space by its id instead.
image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/ffe2bn2dk/")

# Prompt fragments and seed forwarded to the image Space.
image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
image_positive_prompt = ""  # appended to the user's story prompt
image_seed = 9
|
52 |
|
53 |
processor = AutoProcessor.from_pretrained("suno/bark")
|
|
|
58 |
|
59 |
# convert to bettertransformer
|
60 |
model = model.to_bettertransformer()
|
61 |
+
BATCH_SIZE = 16
|
62 |
|
63 |
# enable CPU offload
|
64 |
model.enable_cpu_offload()
|
65 |
|
66 |
+
# MISTRAL ONLY
# Reply attributed to the model right after the system prompt.
default_system_understand_message = "I understand, I am a Mistral chatbot."

# A deployment can replace that reply via the SYSTEM_UNDERSTAND_MESSAGE
# environment variable; otherwise the default above is used.
system_understand_message = os.getenv(
    "SYSTEM_UNDERSTAND_MESSAGE",
    default_system_understand_message,
)
|
73 |
+
|
74 |
+
# Mistral formatter
|
75 |
+
def format_prompt(message):
    """Wrap ``message`` in the Mistral ``[INST]`` chat template.

    The module-level ``SYST_PROMPT`` is sent as a first instruction turn,
    answered by ``system_understand_message``; ``message`` then follows as a
    second instruction turn.

    Args:
        message: The user's request (here, the story prompt).

    Returns:
        The full prompt string to send to the text-generation endpoint.
    """
    # First turn: system guidelines plus the canned acknowledgement.
    system_turn = "".join(
        ("<s>[INST]", SYST_PROMPT, "[/INST]", system_understand_message, "</s>")
    )
    # Second turn: the actual request.
    return f"{system_turn}[INST] {message} [/INST]"
|
81 |
+
|
82 |
+
|
83 |
+
def generate_story(
    story_prompt,
    temperature=0.9,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Ask the hosted Mistral model for a story matching ``story_prompt``.

    Args:
        story_prompt: Description of the story to write.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty forwarded unchanged to the endpoint.

    Returns:
        The generated text. If the request fails, a short apology sentence is
        returned instead of raising, after a console log and a Gradio warning.
    """
    # Coerce in the same order as before: temperature first, then top_p.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,  # fixed seed
    }

    try:
        return text_client.text_generation(
            format_prompt(story_prompt),
            details=False,
            return_full_text=False,
            **sampling_options,
        )
    except Exception as err:
        reason = str(err)

    # (substring of the error, console log, UI warning, text returned as story)
    known_failures = (
        (
            "Too Many Requests",
            "ERROR: Too many requests on mistral client",
            "Unfortunately Mistral is unable to process",
            "Unfortuanately I am not able to process your request now, too many people are asking me !",
        ),
        (
            "Model not loaded on the server",
            "ERROR: Mistral server down",
            "Unfortunately Mistral LLM is unable to process",
            "Unfortuanately I am not able to process your request now, I have problem with Mistral!",
        ),
    )
    for marker, log_line, ui_warning, fallback_text in known_failures:
        if marker in reason:
            print(log_line)
            gr.Warning(ui_warning)
            return fallback_text

    # Anything else: log the raw error and return a generic reply.
    print("Unhandled Exception: ", reason)
    gr.Warning("Unfortunately Mistral is unable to process")
    return "I do not know what happened but I could not understand you."
|
127 |
|
128 |
|
129 |
def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
    """Write a story for ``story_prompt``, narrate it and illustrate it.

    The text comes from ``generate_story``. The illustration is requested from
    ``image_client`` right away and collected at the end, so it is produced
    while the audio is being synthesised batch by batch.

    Args:
        story_prompt: Description of the story to tell.
        voice_preset: Voice preset forwarded to the Bark processor.

    Returns:
        A tuple ``(story, (sampling_rate, audio), img)``: the story text, the
        sampling rate with the concatenated waveform, and the value returned
        by the image job (``None`` when the image job failed).
    """
    story = generate_story(story_prompt)

    print(story)

    # Bark is fed one sentence per batch item.
    sentences = nltk.sent_tokenize(story.replace("\n", " ").strip())

    print("text generated - now calling for image")
    job_img = image_client.submit(
        story_prompt+image_positive_prompt,  # str in 'parameter_11' Textbox component
        image_negative_prompt,  # str in 'parameter_12' Textbox component
        25,
        7,
        1024,
        1024,
        image_seed,
        fn_index=0,
    )
    print("image called - now generating audio")

    pieces = []
    # `start` is already a sentence offset because range() steps by BATCH_SIZE.
    # Multiplying it by BATCH_SIZE again (as the slice used to do) made every
    # batch after the first start at the wrong offset and dropped sentences.
    for batch_index, start in enumerate(range(0, len(sentences), BATCH_SIZE)):
        batch = sentences[start:start + BATCH_SIZE]
        # NOTE(review): one unchanged line of the original loop body was hidden
        # by the diff view at this point; assumed to be blank -- confirm.
        inputs = processor(batch, voice_preset=voice_preset)

        speech_output, output_lengths = model.generate(**inputs.to(device), return_output_lengths=True, min_eos_p=0.2)

        # Keep only the first `length` samples of each generated waveform.
        speech_output = [output[:length].cpu().numpy() for (output, length) in zip(speech_output, output_lengths)]

        print(f"{batch_index}-th part generated")
        # One pause after each batch of sentences.
        pieces += [*speech_output, silence.copy()]

    print("Calling image")

    # A failed image job must not throw away the story and its audio.
    try:
        img = job_img.result()
    except Exception as err:
        print("Image generation failed: ", str(err))
        gr.Warning("Unfortunately the illustration could not be generated")
        img = None

    # np.concatenate raises on an empty list (story with no sentences).
    audio = np.concatenate(pieces) if pieces else silence.copy()

    return story, (sampling_rate, audio), img
|
172 |
|
173 |
|
174 |
|
175 |
|
176 |
# Gradio blocks demo
|
177 |
with gr.Blocks() as demo_blocks:
|
178 |
+
gr.Markdown("""<h1 align="center">🐶Children story</h1>""")
|
179 |
gr.HTML("""<h3 style="text-align:center;">📢Audio Streaming powered by Gradio (v3.40.0 onwards)🦾! </h3>""")
|
180 |
with gr.Group():
|
181 |
with gr.Row():
|
|
|
190 |
|
191 |
with gr.Row():
|
192 |
btn = gr.Button("Create a story")
|
193 |
+
|
194 |
+
with gr.Row():
|
195 |
+
with gr.Column(scale=1):
|
196 |
+
image_output = gr.Image(elem_id="gallery")
|
197 |
with gr.Row():
|
198 |
out_audio = gr.Audio(
|
199 |
streaming=False, autoplay=True) # needed to stream output audio
|
200 |
out_text = gr.Text()
|
201 |
+
btn.click(generate_audio_and_image, [inp_text], [out_text, out_audio, image_output] ) #[out_audio]) #, out_count])
|
202 |
+
|
203 |
+
|
204 |
|
205 |
demo_blocks.queue().launch(debug=True)
|