Spaces:
Sleeping
Sleeping
from tqdm import tqdm | |
from ttts.utils.infer_utils import load_model | |
import json | |
import torch.nn.functional as F | |
import torch | |
import os | |
def read_jsonl(path): | |
path = os.path.expanduser(path) | |
with open(path, 'r') as f: | |
json_str = f.read() | |
data_list = [] | |
for line in json_str.splitlines(): | |
data = json.loads(line) | |
data_list.append(data) | |
return data_list | |
def classify_audio_clip(clip, classifier): | |
""" | |
Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. | |
:param clip: torch tensor containing audio waveform data (get it from load_audio) | |
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real. | |
""" | |
with torch.no_grad(): | |
results = F.softmax(classifier(clip), dim=-1) | |
return results | |
class MelDataset(torch.utils.data.Dataset): | |
def __init__(self,paths): | |
super().__init__() | |
self.paths = paths | |
self.pad_to=700 | |
def __getitem__(self,index): | |
path = self.paths[index] | |
try: | |
mel = torch.load(path+'.mel.pth') | |
except: | |
mel = torch.zeros((1,100,self.pad_to)) | |
if mel.shape[-1] >= self.pad_to: | |
start = torch.randint(0, mel.shape[-1] - self.pad_to+1, (1,)) | |
mel = mel[:, :, start:start+self.pad_to] | |
else: | |
padding_needed = self.pad_to - mel.shape[-1] | |
mel = F.pad(mel, (0,padding_needed)) | |
mel = mel.squeeze(0) | |
return mel,path | |
def __len__(self): | |
return len(self.paths) | |
if __name__=='__main__': | |
model_path = '/home/hyc/tortoise_plus_zh/ttts/classifier/logs/2023-11-23-17-34-45/model-9.pt' | |
config_path = '~/tortoise_plus_zh/ttts/classifier/config.json' | |
device = 'cuda' | |
classifier = load_model('classifier', model_path, config_path, device) | |
jsonl_path = '~/tortoise_plus_zh/ttts/datasets/all_data.jsonl' | |
audiopaths_and_text = read_jsonl(jsonl_path) | |
audio_paths = [x['path'] for x in audiopaths_and_text] | |
ds = MelDataset(audio_paths) | |
dl = torch.utils.data.DataLoader(ds,batch_size=1024,num_workers=16) | |
for _,batch in tqdm(enumerate(dl),total=len(dl)): | |
mels, paths = batch | |
mels = mels.to(device) | |
label = classify_audio_clip(mels,classifier) | |
for i in range(label.shape[0]): | |
if label[i][0]<0.1: | |
with open('ttts/classifier/noise_files.txt','a') as f: | |
# print(os.path.join(os.getcwd(),paths[i])) | |
f.write(os.path.join(os.getcwd(),paths[i])+'\n') | |