import os

import gradio as gr
import nltk  # we'll use this to split the story into sentences
import numpy as np
import torch
from gradio_client import Client
from huggingface_hub import InferenceClient
from transformers import AutoProcessor, BarkModel

nltk.download("punkt")
def _grab_best_device(use_gpu=True):
    """Return "cuda" when a GPU is available and requested, else "cpu"."""
    if torch.cuda.device_count() > 0 and use_gpu:
        device = "cuda"
    else:
        device = "cpu"
    return device


device = _grab_best_device()
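# Note: the float16 weights and flash attention used for Bark below assume a
# CUDA device; on a CPU-only machine you would want to drop those options.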
SYST_PROMPT = """You're the storyteller, crafting a short tale for young listeners. Please abide by these guidelines:
- Keep your sentences short, concise and easy to understand.
- There should be only the narrator speaking. If there are dialogues, they should be indirect."""

# story_prompt = "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
story_prompt = "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom."

# Sampling parameters for text generation.
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

TIMEOUT = int(os.environ.get("TIMEOUT", 45))
# TODO: requirements: accelerate optimum
text_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1",
    timeout=TIMEOUT,
)

image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/ffe2bn2dk/")
image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
image_positive_prompt = ""
image_seed = 6
processor = AutoProcessor.from_pretrained("suno/bark")


def format_speaker_key(key):
    key = key.replace("v2/", "").split("_")
    return f"Speaker {key[2]} ({key[0]})"


# Keep only the English v2 presets, keyed by their human-readable name.
voice_presets = [key for key in processor.speaker_embeddings.keys() if "v2/en" in key]
voice_presets_dict = {format_speaker_key(key): key for key in voice_presets}
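# For illustration, assuming the standard Bark presets, voice_presets_dict maps
# {"Speaker 0 (en)": "v2/en_speaker_0", ..., "Speaker 9 (en)": "v2/en_speaker_9"}.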
model = BarkModel.from_pretrained(
    "suno/bark", torch_dtype=torch.float16, use_flash_attention_2=True
).to(device)
sampling_rate = model.generation_config.sample_rate
silence = np.zeros(int(0.25 * sampling_rate))  # quarter second of silence

BATCH_SIZE = 32

# Enable CPU offload. enable_cpu_offload() works in place and returns None,
# so its result must not be assigned back to `model`.
model.enable_cpu_offload()
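# Note: CPU offload requires the `accelerate` package (see the requirements
# TODO above); it keeps each Bark sub-model on CPU and only moves it to the
# GPU while it is actually running, trading some speed for memory.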
# MISTRAL ONLY
default_system_understand_message = "I understand, I am a Mistral chatbot."
system_understand_message = os.environ.get(
    "SYSTEM_UNDERSTAND_MESSAGE", default_system_understand_message
)


# Mistral instruction formatter
def format_prompt(message):
    prompt = (
        "<s>[INST]" + SYST_PROMPT + "[/INST]" + system_understand_message + "</s>"
    )
    prompt += f"[INST] {message} [/INST]"
    return prompt
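# For illustration, format_prompt("A story about a panda") produces roughly:
#   <s>[INST]<SYST_PROMPT>[/INST]I understand, I am a Mistral chatbot.</s>[INST] A story about a panda [/INST]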
def generate_story(
    story_prompt,
    temperature=0.9,
    max_new_tokens=1024,
    top_p=0.95,
    repetition_penalty=1.0,
):
    # Clamp temperature away from zero to keep sampling numerically stable.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    try:
        output = text_client.text_generation(
            format_prompt(story_prompt),
            **generate_kwargs,
            details=False,
            return_full_text=False,
        )
    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "Unfortunately I am not able to process your request now, too many people are asking me!"
        elif "Model not loaded on the server" in str(e):
            print("ERROR: Mistral server down")
            gr.Warning("Unfortunately Mistral LLM is unable to process")
            output = "Unfortunately I am not able to process your request now, I have a problem with Mistral!"
        else:
            print("Unhandled Exception: ", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "I do not know what happened but I could not understand you."

    return output
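# Hypothetical usage, for illustration only:
#   story = generate_story("A panda going on an adventure with a caterpillar.")
#   print(story)  # a short tale, or one of the fallback messages on failure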
def generate_audio_and_image(story_prompt, voice_preset="Speaker 3 (en)"):
    story = generate_story(story_prompt)
    print(story)

    # Bark works best on sentence-sized chunks, so split the story with nltk.
    model_input = story.replace("\n", " ").strip()
    model_input = nltk.sent_tokenize(model_input)

    print("text generated - now calling for image")
    # Submit the image job asynchronously so it runs while the audio is generated.
    # The positional arguments (25, 7, 1024, 1024) are assumed to be inference
    # steps, guidance scale, width and height of this Space's fn_index=0 endpoint.
    job_img = image_client.submit(
        story_prompt + image_positive_prompt,  # str in 'parameter_11' Textbox component
        image_negative_prompt,  # str in 'parameter_12' Textbox component
        25,
        7,
        1024,
        1024,
        image_seed,
        fn_index=0,
    )

    print("image called - now generating audio")
    pieces = []
    for i in range(0, len(model_input), BATCH_SIZE):
        inputs = model_input[i : i + BATCH_SIZE]
        if len(inputs) != 0:
            inputs = processor(inputs, voice_preset=voice_presets_dict[voice_preset])

            speech_output, output_lengths = model.generate(
                **inputs.to(device), return_output_lengths=True, min_eos_p=0.2
            )
            # Trim each waveform to its true length before moving it off the GPU.
            speech_output = [
                output[:length].cpu().numpy()
                for (output, length) in zip(speech_output, output_lengths)
            ]
            print(f"batch starting at sentence {i} generated")
            # Separate the batches with a quarter second of silence.
            pieces += [*speech_output, silence.copy()]

    print("Calling image")
    try:
        img = job_img.result()
    except Exception as e:
        print("Unhandled Exception: ", str(e))
        gr.Warning("Unfortunately there was an issue when generating the image with SDXL.")
        img = None

    return story, (sampling_rate, np.concatenate(pieces)), img
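# gr.Audio accepts a (sample_rate, numpy_array) tuple, which is why the audio
# is returned as (sampling_rate, np.concatenate(pieces)) above.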
# Gradio Blocks demo
with gr.Blocks() as demo_blocks:
    gr.Markdown("""<h1 align="center">🐶 Children's story</h1>""")
    gr.HTML("""<h3 style="text-align:center;">Let Mistral tell you a story</h3>""")
    with gr.Group():
        with gr.Row():
            inp_text = gr.Textbox(label="Story prompt", info="Enter text here")
        with gr.Row():
            with gr.Accordion("Advanced settings", open=False):
                voice_preset = gr.Dropdown(
                    list(voice_presets_dict),  # display names, e.g. "Speaker 6 (en)"
                    value="Speaker 6 (en)",
                    label="Available speakers",
                )
        with gr.Row():
            btn = gr.Button("Create a story")
        with gr.Row():
            with gr.Column(scale=1):
                image_output = gr.Image(elem_id="gallery")
        with gr.Row():
            out_audio = gr.Audio(streaming=False, autoplay=True)  # autoplay the finished audio (no streaming)
            out_text = gr.Text()
    btn.click(generate_audio_and_image, [inp_text, voice_preset], [out_text, out_audio, image_output])
    with gr.Row():
        gr.Examples(
            [
                "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson.",
                "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom.",
                "Tell me about the wonders of the world.",
            ],
            [inp_text],
            [out_text, out_audio, image_output],
            generate_audio_and_image,
            cache_examples=True,
        )
    with gr.Row():
        gr.Markdown(
            """
            This Space uses **[Bark](https://huggingface.co/docs/transformers/main/en/model_doc/bark)**, [Mistral-7b-instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [Fast SD-XL](https://huggingface.co/spaces/openskyml/fast-sdxl-stable-diffusion-xl)!
            """
        )

demo_blocks.queue().launch(debug=True)