from gradio_client import Client
import torch
import nltk  # we'll use this to split the story into sentences
import numpy as np
from transformers import BarkModel, AutoProcessor
import gradio as gr

nltk.download("punkt", quiet=True)  # sentence tokenizer data required by nltk.sent_tokenize below
def _grab_best_device(use_gpu=True):
    if torch.cuda.device_count() > 0 and use_gpu:
        device = "cuda"
    else:
        device = "cpu"
    return device

device = _grab_best_device()
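# NOTE: this falls back to CPU when no GPU is visible. The Bark checkpoint
# below is loaded in float16, which assumes CUDA; on CPU you would typically
# drop torch_dtype=torch.float16 and keep the default float32 weights.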
BATCH_SIZE = 8  # number of sentences sent to Bark per generation call
SYST_PROMPT = """You're the storyteller, crafting a short tale for young listeners. Please abide by these guidelines:
- Keep your sentences concise and easy to understand.
- There should be only the narrator speaking. No dialogue."""
# story_prompt = "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
story_prompt = "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom."

temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2
text_client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
image_client = Client("prodia/fast-stable-diffusion")

image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly"
image_positive_prompt = ". Cartoon, anime"  # suffix appended to the story prompt to steer the art style
image_seed = 9
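# NOTE: both Clients above call third-party, publicly hosted Spaces, so their
# endpoints and availability may change. The image client is only referenced by
# the commented-out job inside generate_audio_and_image below.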
processor = AutoProcessor.from_pretrained("suno/bark")
model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16).to(device)

sampling_rate = model.generation_config.sample_rate
silence = np.zeros(int(0.25 * sampling_rate))  # quarter second of silence
voice_preset = "v2/en_speaker_6"

# convert to BetterTransformer
model = model.to_bettertransformer()
# enable CPU offload
model.enable_cpu_offload()
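# BetterTransformer (provided by the optimum package) swaps Bark's attention
# for fused kernels, which typically speeds up generation without changing the
# output. enable_cpu_offload() (requires accelerate) keeps each Bark sub-model
# on the GPU only while it is actually running, trading a little latency for a
# much smaller peak VRAM footprint.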
def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
    # 1. Generate the story text with the remote Llama 2 Space.
    story = text_client.predict(
        story_prompt,
        SYST_PROMPT,
        temperature,
        4096,  # max new tokens
        top_p,
        repetition_penalty,
        api_name="/chat",
    )
    print(story)

    # 2. Split the story into sentences and synthesize speech batch by batch,
    #    appending a short silence after each batch so sentences don't run together.
    model_input = story.replace("\n", " ").strip()
    model_input = nltk.sent_tokenize(model_input)

    pieces = []
    for i in range(0, len(model_input), BATCH_SIZE):
        inputs = model_input[i : i + BATCH_SIZE]  # i already advances by BATCH_SIZE
        if len(inputs) != 0:
            inputs = processor(inputs, voice_preset=voice_preset)
            speech_output = model.generate(**inputs.to(device)).cpu().numpy()
            pieces += [*speech_output, silence.copy()]
    # job_img = image_client.submit(
    #     story_prompt + image_positive_prompt,  # str in 'parameter_11' Textbox component
    #     image_negative_prompt,  # str in 'parameter_12' Textbox component
    #     "absolutereality_v181.safetensors [3d9d4d2b]",  # str (Option from: ['absolutereality_V16.safetensors [37db0fc3]', 'absolutereality_v181.safetensors [3d9d4d2b]', 'analog-diffusion-1.0.ckpt [9ca13f02]', 'anythingv3_0-pruned.ckpt [2700c435]', 'anything-v4.5-pruned.ckpt [65745d25]', 'anythingV5_PrtRE.safetensors [893e49b9]', 'AOM3A3_orangemixs.safetensors [9600da17]', 'childrensStories_v13D.safetensors [9dfaabcb]', 'childrensStories_v1SemiReal.safetensors [a1c56dbb]', 'childrensStories_v1ToonAnime.safetensors [2ec7b88b]', 'cyberrealistic_v33.safetensors [82b0d085]', 'deliberate_v2.safetensors [10ec4b29]', 'deliberate_v3.safetensors [afd9d2d4]', 'dreamlike-anime-1.0.safetensors [4520e090]', 'dreamlike-diffusion-1.0.safetensors [5c9fd6e0]', 'dreamlike-photoreal-2.0.safetensors [fdcf65e7]', 'dreamshaper_6BakedVae.safetensors [114c8abb]', 'dreamshaper_7.safetensors [5cf5ae06]', 'dreamshaper_8.safetensors [9d40847d]', 'edgeOfRealism_eorV20.safetensors [3ed5de15]', 'EimisAnimeDiffusion_V1.ckpt [4f828a15]', 'elldreths-vivid-mix.safetensors [342d9d26]', 'epicrealism_naturalSinRC1VAE.safetensors [90a4c676]', 'ICantBelieveItsNotPhotography_seco.safetensors [4e7a3dfd]', 'juggernaut_aftermath.safetensors [5e20c455]', 'lyriel_v16.safetensors [68fceea2]', 'mechamix_v10.safetensors [ee685731]', 'meinamix_meinaV9.safetensors [2ec66ab0]', 'meinamix_meinaV11.safetensors [b56ce717]', 'openjourney_V4.ckpt [ca2f377f]', 'portraitplus_V1.0.safetensors [1400e684]', 'Realistic_Vision_V1.4-pruned-fp16.safetensors [8d21810b]', 'Realistic_Vision_V2.0.safetensors [79587710]', 'Realistic_Vision_V4.0.safetensors [29a7afaa]', 'Realistic_Vision_V5.0.safetensors [614d1063]', 'redshift_diffusion-V10.safetensors [1400e684]', 'revAnimated_v122.safetensors [3f4fefd9]', 'rundiffusionFX25D_v10.safetensors [cd12b0ee]', 'rundiffusionFX_v10.safetensors [cd4e694d]', 'sdv1_4.ckpt [7460a6fa]', 'v1-5-pruned-emaonly.safetensors [d7049739]', 'shoninsBeautiful_v10.safetensors [25d8c546]', 'theallys-mix-ii-churned.safetensors [5d9225a4]', 'timeless-1.0.ckpt [7c4971d4]', 'toonyou_beta6.safetensors [980f6b15]'])
    #     25,
    #     "Euler a",
    #     7,
    #     512,
    #     512,
    #     image_seed,
    #     "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png,https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",  # str (path to directory with images and a file associating images with captions called captions.json)
    #     fn_index=0
    # )
    # img = job_img.result()
    return story, (sampling_rate, np.concatenate(pieces))
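# A minimal sanity check, assuming the remote Spaces above are up (the prompt
# is illustrative only):
#   story, (sr, audio) = generate_audio_and_image("A turtle who learns to fly.")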
# Gradio Blocks demo
with gr.Blocks() as demo_blocks:
    gr.Markdown("""<h1 align="center">🐶 Children's story</h1>""")
    gr.HTML("""<h3 style="text-align:center;">📢 Audio streaming powered by Gradio (v3.40.0 onwards) 🦾!</h3>""")
    with gr.Group():
        with gr.Row():
            inp_text = gr.Textbox(label="Story prompt", info="Enter text here")
            # dd = gr.Dropdown(
            #     speaker_embeddings,
            #     value=None,
            #     label="Available voice presets",
            #     info="Defaults to no speaker embeddings!"
            # )
        with gr.Row():
            btn = gr.Button("Create a story")
        with gr.Row():
            out_audio = gr.Audio(streaming=False, autoplay=True)  # play the full clip as soon as it is ready
            out_text = gr.Text()

    btn.click(generate_audio_and_image, [inp_text], [out_text, out_audio])

demo_blocks.queue().launch(debug=True)
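# queue() makes requests line up instead of hitting the model concurrently,
# which matters when a single GPU serves every user; debug=True keeps the
# process attached and prints server logs. For a public link, launch(share=True)
# could be used instead.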