Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,811 Bytes
f3f28d3 aca4b77 f3f28d3 aca4b77 f3f28d3 0faef57 7e86f55 aca4b77 ccb7c0b aca4b77 f3f28d3 aca4b77 f3f28d3 012fbfa 9a7b1fe 012fbfa f3f28d3 012fbfa f3f28d3 9a7b1fe 012fbfa f3f28d3 9a7b1fe b3760ba 9a7b1fe f3f28d3 b3760ba f3f28d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import json
import numpy as np
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
class dotdict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes."""

    def __getattr__(self, name):
        # Missing keys yield None instead of raising, exactly like dict.get.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment stores the value as a regular dictionary item.
        self[name] = value

    def __delattr__(self, name):
        # Deleting an absent attribute raises KeyError, as item deletion does.
        del self[name]
class InferRunner:
    """Load and hold the models used for caption-to-audio inference.

    After construction the instance exposes:
      * ``vae`` - ``AutoencoderKL`` built from ``ckpts/ldm/vae_config.json``
        with weights from ``ckpts/ldm/pytorch_model_vae.bin``.
      * ``pico_model`` - ``PicoDiffusion`` in eval mode, configured from the
        first record of ``ckpts/pico_model/summary.jsonl``.
      * ``scheduler`` - ``DDPMScheduler`` loaded from the scheduler name stored
        in that same record.
    """

    def __init__(self, device):
        """Read the checkpoints under ``ckpts/`` and move the models to *device*.

        Args:
            device: Target handed to ``.to(...)`` and used as ``map_location``
                for ``torch.load`` (the script passes ``"cuda"`` or ``"cpu"``).

        Raises:
            OSError: If one of the checkpoint or config files cannot be opened.
            json.JSONDecodeError: If a config file is not valid JSON, including
                an empty ``summary.jsonl``.
        """
        # Context managers close the config files deterministically instead of
        # leaving the handles to the garbage collector.
        with open("ckpts/ldm/vae_config.json", encoding="utf-8") as config_file:
            vae_config = json.load(config_file)

        # NOTE(review): unlike pico_model, the VAE is not switched to eval()
        # here - confirm whether that is intentional.
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
        self.vae.load_state_dict(vae_weights)

        # Only the first line of the summary is used, so read just that line
        # rather than loading the whole file into memory.
        with open("ckpts/pico_model/summary.jsonl", encoding="utf-8") as summary_file:
            first_record = summary_file.readline()
        train_args = dotdict(json.loads(first_record))

        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
            diffusion_pt="ckpts/pico_model/diffusion.pt",
        ).eval().to(device)

        self.scheduler = DDPMScheduler.from_pretrained(
            train_args.scheduler_name, subfolder="scheduler"
        )
# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Built once at module level so every request reuses the same loaded models.
runner = InferRunner(device)
# Names of the 18 sound events listed in the UI; trailing comments give each
# entry's position in the list.
event_list = [
    "burping_belching",           # 0
    "car_horn_honking",           # 1
    "cat_meowing",                # 2
    "cow_mooing",                 # 3
    "dog_barking",                # 4
    "door_knocking",              # 5
    "door_slamming",              # 6
    "explosion",                  # 7
    "gunshot",                    # 8
    "sheep_goat_bleating",        # 9
    "sneeze",                     # 10
    "spraying",                   # 11
    "thump_thud",                 # 12
    "train_horn",                 # 13
    "tapping_clicking_clanking",  # 14
    "woman_laughing",             # 15
    "duck_quacking",              # 16
    "whistling",                  # 17
]
def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
    """Generate audio for *caption* and return the path of the written WAV file.

    Args:
        caption: Prompt text forwarded to ``runner.pico_model.demo_inference``.
        num_steps: Number of inference steps forwarded to ``demo_inference``.
        guidance_scale: Guidance scale forwarded to ``demo_inference``.
        audio_len: Maximum number of samples kept from the decoded waveform;
            the default equals 10 seconds at the 16 kHz rate used for writing.

    Returns:
        Path of a newly created 16-bit PCM WAV file containing the result.
    """
    import tempfile  # local import keeps this function self-contained

    sample_rate = 16000

    # Inference only: no gradients are needed for sampling or decoding.
    with torch.no_grad():
        latents = runner.pico_model.demo_inference(
            caption,
            runner.scheduler,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            num_samples_per_prompt=1,
            disable_progress=True,
        )
        mel = runner.vae.decode_first_stage(latents)
        # Keep the first generated sample, truncated to the requested length.
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]

    # Give every call its own file: a single fixed "output.wav" could be
    # overwritten by another request before its audio has been served.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        outpath = handle.name
    sf.write(outpath, wave, samplerate=sample_rate, subtype='PCM_16')
    return outpath
# NOTE(review): these Markdown components are created outside any Blocks
# context and are not passed to the Interface below, so they are presumably
# never rendered - confirm, and consider Interface(title=..., description=...).
gr.Markdown("## PicoAudio")
gr.Markdown("18 events: " + ", ".join(event_list))

# Input and output components handed to the Interface below.
prompt = gr.Textbox(
    label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
    value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
)
outaudio = gr.Audio()
num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)

gr_interface = gr.Interface(
    fn=infer,
    # Order matches infer's positional parameters: caption, num_steps, guidance_scale.
    inputs=[prompt, num_steps, guidance_scale],
    outputs=[outaudio],
    # Gradio expects a string here; the boolean form is deprecated.
    allow_flagging="never",
    # Each example supplies only the prompt.
    examples=[
        ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
        ["dog_barking at 0.562-2.562_4.25-6.25."],
        ["cow_mooing at 0.958-3.582_5.272-7.896."],
    ],
    cache_examples="lazy",  # Turn on to cache.
)

# NOTE(review): the meaning of queue()'s first positional argument appears to
# differ between Gradio releases - confirm and pass it by keyword.
gr_interface.queue(10).launch()
|