import argparse
import os
import sys
import glob
import pathlib

# The current working directory is appended to sys.path before the
# project-local imports below (ldm, preprocess, vocoder) — presumably so they
# resolve when the script is launched from the repository root.
directory = pathlib.Path(os.getcwd())
print(directory)
sys.path.append(str(directory))

import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import pandas as pd
from tqdm import tqdm
import preprocess.n2s_by_openai as n2s
from vocoder.bigvgan.models import VocoderBigVGAN
import soundfile
import torchaudio
import math


def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate the model described by ``config.model`` and optionally load weights.

    Args:
        config: Configuration object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint holding a ``"state_dict"`` entry. When
            falsy, no weights are loaded and a notice is printed instead.
        verbose: When true, print the missing / unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The model in eval mode. It is moved to CUDA when CUDA is available;
        otherwise it stays on the device it was instantiated on.
    """
    model = instantiate_from_config(config.model)

    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE: torch.load relies on pickle; only load checkpoints from
        # sources you trust.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        # strict=False: key mismatches are reported below instead of raising.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys:")
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)
    else:
        print("Note that no ckpt is loaded !!!")

    # Only move to the GPU when one exists, so CPU-only hosts do not crash
    # here (main() already selects its device with the same check).
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def parse_args(argv=None):
    """Build the command-line parser for the sampling script and parse arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        # Alternative example prompt:
        # "A large truck driving by as an emergency siren wails and truck horn honks"
        default='This instrumental song features a relaxing melody with a country feel, accompanied by a guitar, piano, simple percussion, and bass in a slow tempo',
        help="the prompt to generate",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="sample rate of wav",
    )
    parser.add_argument(
        "--test-dataset",
        default="none",
        help="test which dataset: testset",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2audio-samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=25,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=20,  # keep fixed
        help="latent height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=312,  # keep fixed
        help="latent width, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,  # if it's 1, only condition is taken into consideration
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    # NOTE: with nargs="?" and const=True, a bare `-r` (no value) yields True
    # rather than a path string.
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        type=str,
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default="",
    )
    parser.add_argument(
        "--vocoder-ckpt",
        type=str,
        help="paths to vocoder checkpoint",
        default='vocoder/logs/audioset',
    )
    return parser.parse_args(argv)
class GenSamples:
    """Sample from the diffusion model for a text prompt and save mel / wav outputs."""

    def __init__(self, opt, model, outpath, config, vocoder=None, save_mel=True, save_wav=True) -> None:
        """Store the sampling options and collaborators.

        Args:
            opt: Parsed options; ``scale``, ``n_iter``, ``H``, ``W``,
                ``ddim_steps`` and ``sample_rate`` are read during sampling.
            model: Diffusion model; its ``channels`` attribute is read here.
            outpath: Directory that output files are written into.
            config: Configuration object, stored as-is.
            vocoder: Object with a ``vocode(spec)`` method; required when
                ``save_wav`` is true.
            save_mel: Whether to save each decoded spectrogram as ``.npy``.
            save_wav: Whether to vocode each spectrogram and save a ``.wav``.

        Raises:
            ValueError: If ``save_wav`` is true but ``vocoder`` is ``None``.
        """
        # Explicit raise instead of assert so the check survives `python -O`.
        if save_wav and vocoder is None:
            raise ValueError("a vocoder is required when save_wav=True")
        self.opt = opt
        self.model = model
        self.outpath = outpath
        self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav
        self.channel_dim = self.model.channels
        self.config = config

    def gen_test_sample(self, prompt, mel_name=None, wav_name=None, gt=None, video=None):
        """Run ``opt.n_iter`` sampling passes for one prompt and write the outputs.

        Args:
            prompt: Mapping of the form
                ``{'ori_caption': ..., 'struct_caption': ...}``;
                ``prompt['ori_caption'][0]`` is recorded as the caption.
            mel_name: Base file name for ``.npy`` outputs. Falls back to
                ``wav_name`` when omitted.
            wav_name: Base file name for ``.wav`` outputs. Falls back to
                ``mel_name`` when omitted.
            gt: Unused; kept for call compatibility.
            video: Unused; kept for call compatibility.

        Returns:
            A list with one dict per generated sample, holding ``'caption'``
            plus ``'mel_path'`` and/or ``'audio_path'`` for the files written.

        Raises:
            ValueError: If an output is requested but no base name is available.
        """
        # Resolve file base names up front so a missing name fails before the
        # (expensive) sampling instead of inside path construction.
        if self.save_mel and mel_name is None:
            mel_name = wav_name
        if self.save_wav and wav_name is None:
            wav_name = mel_name
        if (self.save_mel and mel_name is None) or (self.save_wav and wav_name is None):
            raise ValueError("mel_name or wav_name must be given to name the output files")

        uc = None
        if self.opt.scale != 1.0:
            try:  # dict-style conditioning (labelled "audiocaps" by the original author)
                uc = self.model.get_learned_conditioning({'ori_caption': "", 'struct_caption': ""})
            except Exception:  # plain-caption conditioning (labelled "audioset")
                # NOTE(review): this fallback conditions on the caption itself
                # rather than an empty prompt — confirm this is intended.
                uc = self.model.get_learned_conditioning(prompt['ori_caption'])

        record_dicts = []
        # Running index across all iterations so that n_iter > 1 does not
        # overwrite earlier files; identical to the per-batch index when
        # n_iter == 1.
        sample_idx = 0
        for _ in range(self.opt.n_iter):
            try:  # dict-style conditioning ("audiocaps")
                # Original author's note: shape [1, 77, 1280], i.e. still
                # per-token embeddings, not yet a sentence embedding.
                c = self.model.get_learned_conditioning(prompt)
            except Exception:  # plain-caption conditioning ("audioset")
                c = self.model.get_learned_conditioning(prompt['ori_caption'])

            # (z_dim, 80//2^x, 848//2^x)
            channels = self.channel_dim if self.channel_dim > 0 else 1
            shape = [channels, self.opt.H, self.opt.W]
            x0 = torch.randn(shape, device=self.model.device)

            if self.opt.scale == 1:  # without classifier-free guidance
                sample, _ = self.model.sample(c, 1, timesteps=self.opt.ddim_steps, x_latent=x0)
            else:  # classifier-free guidance
                sample, _ = self.model.sample_cfg(c, self.opt.scale, uc, 1, timesteps=self.opt.ddim_steps, x_latent=x0)
            x_samples_ddim = self.model.decode_first_stage(sample)

            for spec in x_samples_ddim:
                spec = spec.squeeze(0).cpu().numpy()
                record_dict = {'caption': prompt['ori_caption'][0]}
                if self.save_mel:
                    mel_path = os.path.join(self.outpath, mel_name + f'_{sample_idx}.npy')
                    np.save(mel_path, spec)
                    record_dict['mel_path'] = mel_path
                if self.save_wav:
                    wav = self.vocoder.vocode(spec)
                    wav_path = os.path.join(self.outpath, wav_name + f'_{sample_idx}.wav')
                    soundfile.write(wav_path, wav, self.opt.sample_rate)
                    record_dict['audio_path'] = wav_path
                record_dicts.append(record_dict)
                sample_idx += 1
        return record_dicts
def main():
    """Script entry point: load the model and vocoder, then generate audio.

    With ``--test-dataset testset`` the dataset configured under
    ``config['test_dataset']`` is iterated and a tab-separated ``result.csv``
    is written to the output directory. With the default ``none`` a single
    sample is generated from ``--prompt``.

    Raises:
        NotImplementedError: For any ``--test-dataset`` value other than
            ``none`` or ``testset``.
    """
    opt = parse_args()
    # torch.manual_seed(55)  # uncomment to fix the random seed
    config = OmegaConf.load(opt.base)
    model = load_model_from_config(config, opt.resume)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    os.makedirs(opt.outdir, exist_ok=True)
    vocoder = VocoderBigVGAN(opt.vocoder_ckpt, device)
    generator = GenSamples(opt, model, opt.outdir, config, vocoder, save_mel=False, save_wav=True)
    csv_dicts = []

    with torch.no_grad(), model.ema_scope():
        if opt.test_dataset != 'none':
            if opt.test_dataset != 'testset':
                raise NotImplementedError
            test_dataset = instantiate_from_config(config['test_dataset'])
            video = None
            print(f"Dataset: {type(test_dataset)} LEN: {len(test_dataset)}")

            # NOTE(review): the first item is skipped and generation stops
            # after three items. This mirrors the original counters and looks
            # like a debugging limit — confirm before a full evaluation run.
            seen = 0
            generated = 0
            for item in tqdm(test_dataset):
                seen += 1
                if seen < 2:
                    continue
                prompt, f_name, gt = item['caption'], item['f_name'], item['image']
                # f_name has the form video_name + '_' + num; split at the last '_'.
                split_at = f_name.rfind('_')
                v_n, num = f_name[:split_at], f_name[split_at + 1:]
                mel_name = f'{v_n}_sample_{num}'
                wav_name = f'{v_n}_sample_{num}'
                csv_dicts.extend(
                    generator.gen_test_sample(prompt, mel_name=mel_name, wav_name=wav_name, gt=gt, video=video)
                )
                generated += 1
                if generated >= 3:
                    break

            df = pd.DataFrame.from_dict(csv_dicts)
            df.to_csv(os.path.join(opt.outdir, 'result.csv'), sep='\t', index=False)
        else:
            ori_caption = opt.prompt
            struct_caption = n2s.get_struct(ori_caption)
            # struct_caption = f'<{ori_caption}& all>'  # alternative without calling n2s
            print(f"The structed caption by Chatgpt is : {struct_caption}")
            wav_name = f'{ori_caption.strip().replace(" ", "-")}'
            prompt = {'ori_caption': [ori_caption], 'struct_caption': [struct_caption]}
            generator.gen_test_sample(prompt, wav_name=wav_name)

    print(f"Your samples are ready and waiting for you here: \n{opt.outdir} \nEnjoy.")
# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()