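# children-story / app.py
# Hugging Face Space file by ylacombe, commit 5bb18c5 ("Update app.py"), 7.87 kB.
# (Header copied from the Space web page; its raw / history / blame links are not part of the code.)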
from huggingface_hub import InferenceClient
from gradio_client import Client
import torch
import nltk # we'll use this to split into sentences
import numpy as np
from transformers import BarkModel, AutoProcessor
nltk.download('punkt')
import gradio as gr
import os
def _grab_best_device(use_gpu=True):
    """Choose the torch device name used for the Bark model.

    Args:
        use_gpu: when falsy, CPU is returned even if CUDA devices exist.

    Returns:
        ``"cuda"`` if at least one CUDA device is visible and ``use_gpu`` is
        truthy, otherwise ``"cpu"``.
    """
    cuda_visible = torch.cuda.device_count() > 0
    return "cuda" if (cuda_visible and use_gpu) else "cpu"
device = _grab_best_device()

# System instruction sent to Mistral as the first [INST] turn by format_prompt.
SYST_PROMPT="""You're the storyteller, crafting a short tale for young listeners. Please abide by these guidelines:
- Keep your sentences short, concise and easy to understand.
- There should be only the narrator speaking. If there are dialogues, they should be indirect."""

# Sample prompt kept at module level; the functions below all receive
# story_prompt as a parameter instead of reading this global.
story_prompt = "A princess breaks free from a dragon's grip. This evocates women empowerement and freedom."

# Sampling settings (previously assigned twice with identical values).
# NOTE(review): generate_audio_and_image calls generate_story(story_prompt)
# without passing these, so generate_story's own defaults are what is used.
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

# Timeout for the Mistral InferenceClient, overridable with the TIMEOUT env var.
TIMEOUT = int(os.environ.get("TIMEOUT", 45))

# TODO: requirements: accelerate optimum
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1",
    timeout=TIMEOUT,
)

# Remote SDXL Space driven through gradio_client.
# NOTE(review): this URL targets a specific `--replicas/...` instance, which may
# stop resolving; consider the Space id "openskyml/fast-sdxl-stable-diffusion-xl".
image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/ffe2bn2dk/")
image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
image_positive_prompt = ""  # appended to the user's story prompt for the image request
image_seed = 6

# Bark processor: tokenizes text and provides the speaker-preset embeddings.
processor = AutoProcessor.from_pretrained("suno/bark")
def format_speaker_key(key):
    """Build a human-readable label for a Bark speaker preset id.

    Example: ``"v2/en_speaker_6"`` becomes ``"Speaker 6 (en)"``. The first
    underscore-separated field is the language and the third is the speaker
    number; any ``"v2/"`` prefix is dropped first.
    """
    fields = key.replace("v2/", "").split("_")
    language, number = fields[0], fields[2]
    return f"Speaker {number} ({language})"
# English Bark speaker presets only, exposed in the UI under readable labels
# (label -> preset id, e.g. "Speaker 6 (en)" -> "v2/en_speaker_6").
voice_presets = [key for key in processor.speaker_embeddings.keys() if "v2/en" in key]
voice_presets_dict = {
    format_speaker_key(key): key for key in voice_presets
}

model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
sampling_rate = model.generation_config.sample_rate
silence = np.zeros(int(0.25 * sampling_rate))  # quarter second of silence
voice_preset = "v2/en_speaker_6"
BATCH_SIZE = 32  # maximum number of sentences per Bark generate() call

# enable CPU offload (needs `accelerate`, see the requirements TODO above).
# BarkModel.enable_cpu_offload works in place and returns None: assigning its
# result back to `model` would replace the model with None, and one call is enough.
model.enable_cpu_offload()
# MISTRAL ONLY
# Assistant reply placed after the system-prompt turn by format_prompt.
# The SYSTEM_UNDERSTAND_MESSAGE environment variable overrides the default.
default_system_understand_message = "I understand, I am a Mistral chatbot."
system_understand_message = os.getenv(
    "SYSTEM_UNDERSTAND_MESSAGE",
    default_system_understand_message,
)
# Mistral formatter
def format_prompt(message):
    """Wrap a user message in the Mistral-instruct prompt layout.

    SYST_PROMPT is sent as a first ``[INST]`` turn and answered by
    ``system_understand_message``. ``message`` then follows as the second
    ``[INST] ... [/INST]`` turn, which the model is expected to complete.
    """
    system_turn = f"<s>[INST]{SYST_PROMPT}[/INST]{system_understand_message}</s>"
    user_turn = f"[INST] {message} [/INST]"
    return system_turn + user_turn
def generate_story(
    story_prompt,
    temperature=0.9,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Ask the Mistral inference endpoint to write a story for ``story_prompt``.

    Args:
        story_prompt: the user's request, wrapped by ``format_prompt``.
        temperature: sampling temperature; anything below 1e-2 is raised to 1e-2.
        max_new_tokens: cap on the number of generated tokens.
        top_p: nucleus-sampling threshold (converted to float).
        repetition_penalty: forwarded unchanged to ``text_generation``.

    Returns:
        The generated text. If formatting or sending the request raises any
        Exception, a Gradio warning is shown and a canned message is returned
        instead, so errors are not propagated to the caller.
    """
    # Keep the temperature strictly positive, presumably because the endpoint
    # rejects non-positive values when sampling.
    temperature = max(float(temperature), 1e-2)
    sampling_options = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,  # fixed seed, presumably so the same prompt reproduces the same story
    }
    try:
        return text_client.text_generation(
            format_prompt(story_prompt),
            **sampling_options,
            details=False,
            return_full_text=False,
        )
    except Exception as e:
        error_text = str(e)
        if "Too Many Requests" in error_text:
            print("ERROR: Too many requests on mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            return "Unfortuanately I am not able to process your request now, too many people are asking me !"
        if "Model not loaded on the server" in error_text:
            print("ERROR: Mistral server down")
            gr.Warning("Unfortunately Mistral LLM is unable to process")
            return "Unfortuanately I am not able to process your request now, I have problem with Mistral!"
        print("Unhandled Exception: ", error_text)
        gr.Warning("Unfortunately Mistral is unable to process")
        return "I do not know what happened but I could not understand you."
def generate_audio_and_image(story_prompt, voice_preset="Speaker 3 (en)"):
    """Write a story, narrate it with Bark and illustrate it with the SDXL Space.

    Args:
        story_prompt: the user's request. It is sent to Mistral through
            ``generate_story`` and, with ``image_positive_prompt`` appended, to SDXL.
        voice_preset: a speaker label from ``voice_presets_dict``
            (e.g. ``"Speaker 6 (en)"``); an unknown label raises KeyError.

    Returns:
        ``(story, (sampling_rate, audio), image)``, in the order of the UI outputs
        ``[out_text, out_audio, image_output]``. ``image`` is None if the SDXL job
        raised. When the story contains no sentences, ``audio`` is a quarter
        second of silence instead of failing.
    """
    story = generate_story(story_prompt)
    print(story)
    # Bark receives one sentence per batch item: flatten line breaks, then split.
    sentences = nltk.sent_tokenize(story.replace("\n", " ").strip())

    print("text generated - now calling for image")
    # Submit the image request first so the remote Space runs while Bark generates audio.
    job_img = image_client.submit(
        story_prompt + image_positive_prompt,  # str in 'parameter_11' Textbox component
        image_negative_prompt,  # str in 'parameter_12' Textbox component
        # NOTE(review): presumably steps, guidance scale, width, height;
        # confirm with image_client.view_api().
        25,
        7,
        1024,
        1024,
        image_seed,
        fn_index=0,
    )
    print("image called - now generating audio")

    preset_id = voice_presets_dict[voice_preset]
    pieces = []
    for start in range(0, len(sentences), BATCH_SIZE):
        batch = sentences[start:start + BATCH_SIZE]  # slicing already clamps at the end
        inputs = processor(batch, voice_preset=preset_id)
        speech_output, output_lengths = model.generate(
            **inputs.to(device), return_output_lengths=True, min_eos_p=0.2
        )
        # Drop samples beyond each item's reported output length.
        waveforms = [
            wave[:length].cpu().numpy()
            for wave, length in zip(speech_output, output_lengths)
        ]
        print(f"{start}-th part generated")
        # Silence is appended once per batch, not between every sentence.
        pieces += [*waveforms, silence.copy()]

    print("Calling image")
    try:
        img = job_img.result()
    except Exception as e:
        print("Unhandled Exception: ", str(e))
        gr.Warning("Unfortunately there was an issue when generating the image with SDXL.")
        img = None

    # np.concatenate raises ValueError on an empty list, which happens when the
    # story has no sentences (e.g. an empty LLM reply); return silence instead.
    audio = np.concatenate(pieces) if pieces else silence.copy()
    return story, (sampling_rate, audio), img
# Gradio blocks demo
with gr.Blocks() as demo_blocks:
    # Page title and subtitle.
    gr.Markdown("""<h1 align="center">🐶Children story</h1>""")
    gr.HTML("""<h3 style="text-align:center;">Let Mistral tell you a story</h3>""")
    with gr.Group():
        # Inputs: story prompt, optional speaker choice, trigger button.
        with gr.Row():
            inp_text = gr.Textbox(label="Story prompt", info="Enter text here")
        with gr.Row():
            with gr.Accordion("Advanced settings", open=False):
                # Dropdown choices are the readable labels (the keys of
                # voice_presets_dict); generate_audio_and_image maps a label
                # back to its Bark preset id.
                voice_preset = gr.Dropdown(
                    choices=voice_presets_dict,
                    value="Speaker 6 (en)",
                    label="Available speakers",
                )
        with gr.Row():
            btn = gr.Button("Create a story")

        # Outputs: illustration, narration and story text.
        with gr.Row():
            with gr.Column(scale=1):
                image_output = gr.Image(elem_id="gallery")
        with gr.Row():
            # The whole waveform is returned at once (no streaming); autoplay
            # starts playback as soon as it arrives.
            out_audio = gr.Audio(streaming=False, autoplay=True)
            out_text = gr.Text()

        btn.click(
            fn=generate_audio_and_image,
            inputs=[inp_text, voice_preset],
            outputs=[out_text, out_audio, image_output],
        )

        with gr.Row():
            # cache_examples=True makes Gradio precompute the outputs for each
            # example by running the full pipeline.
            gr.Examples(
                examples=[
                    "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson.",
                    "A princess breaks free from a dragon's grip. This evocates women empowerement and freedom.",
                    "Tell me about the wonders of the world.",
                ],
                inputs=[inp_text],
                outputs=[out_text, out_audio, image_output],
                fn=generate_audio_and_image,
                cache_examples=True,
            )
        with gr.Row():
            gr.Markdown(
                """
This Space uses **[Bark](https://huggingface.co/docs/transformers/main/en/model_doc/bark)**, [Mistral-7b-instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [Fast SD-XL](https://huggingface.co/spaces/openskyml/fast-sdxl-stable-diffusion-xl)!
"""
            )
demo_blocks.queue().launch(debug=True)