TTTS / ttts /classifier /infer.py
mrfakename's picture
Add source code
4ee33aa
raw
history blame
2.6 kB
from tqdm import tqdm
from ttts.utils.infer_utils import load_model
import json
import torch.nn.functional as F
import torch
import os
def read_jsonl(path):
    """Load a JSON Lines file into a list of decoded objects.

    :param path: path to the ``.jsonl`` file; a leading ``~`` is expanded
        to the current user's home directory.
    :return: list with one decoded JSON value per non-blank line, in file
        order.
    :raises OSError: if the file cannot be opened.
    :raises json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    path = os.path.expanduser(path)
    data_list = []
    # Iterate the file object instead of ``read().splitlines()``: file
    # iteration only breaks on real newlines, whereas ``str.splitlines`` also
    # splits on characters such as U+2028 that may legally appear unescaped
    # inside a JSON string and would cut a record in half.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing empty line) carry no record and
            # would make ``json.loads`` raise.
            if not line.strip():
                continue
            data_list.append(json.loads(line))
    return data_list
def classify_audio_clip(clip, classifier):
    """
    Run ``classifier`` on ``clip`` and return its class probabilities.

    The classifier output is normalised with a softmax over the last
    dimension. Gradient tracking is disabled for the forward pass, so the
    result does not require grad. No thresholding is applied here: the caller
    receives the probability tensor and decides how to interpret it.

    :param clip: input tensor passed unchanged to ``classifier``.
    :param classifier: callable (e.g. an ``nn.Module``) returning a tensor of
        per-class scores whose last dimension indexes the classes.
    :return: tensor shaped like the classifier output, holding probabilities
        that sum to 1 along the last dimension.
    """
    with torch.no_grad():
        scores = classifier(clip)
        probabilities = torch.softmax(scores, dim=-1)
    return probabilities
class MelDataset(torch.utils.data.Dataset):
    """Dataset yielding fixed-length mel windows loaded from ``<path>.mel.pth``.

    Each item is a ``(mel, path)`` pair. The mel is cropped (at a random
    offset) or right-padded with zeros so that its last dimension is exactly
    ``pad_to``, then its leading dimension is squeezed away.

    NOTE(review): stored mels are assumed to be shaped ``(1, n_mels, frames)``,
    matching the zero fallback and the three-axis slicing below — confirm
    against the code that writes the ``.mel.pth`` files.

    :param paths: sequence of path prefixes; ``'.mel.pth'`` is appended to
        each one to locate the saved tensor.
    :param pad_to: number of frames in every returned mel.
    :param n_mels: channel count of the all-zero mel substituted when a file
        cannot be loaded.
    """

    def __init__(self, paths, pad_to=700, n_mels=100):
        super().__init__()
        self.paths = paths
        self.pad_to = pad_to
        self.n_mels = n_mels

    def __getitem__(self, index):
        """Return ``(mel, path)`` for ``self.paths[index]``."""
        path = self.paths[index]
        try:
            mel = torch.load(path + '.mel.pth')
        except Exception:
            # Best effort: a missing or unreadable file becomes an all-zero
            # mel so one bad entry does not abort the whole run.
            # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt
            # and SystemExit propagate.
            mel = torch.zeros((1, self.n_mels, self.pad_to))

        n_frames = mel.shape[-1]
        if n_frames >= self.pad_to:
            # Random crop of exactly ``pad_to`` frames; the upper bound of
            # ``randint`` is exclusive, hence the ``+ 1``.
            start = torch.randint(0, n_frames - self.pad_to + 1, (1,)).item()
            mel = mel[:, :, start:start + self.pad_to]
        else:
            # Zero-pad on the right of the last (frame) dimension.
            mel = F.pad(mel, (0, self.pad_to - n_frames))

        return mel.squeeze(0), path

    def __len__(self):
        """Return the number of paths in the dataset."""
        return len(self.paths)
if __name__ == '__main__':
    # Script entry point: run the classifier over every clip listed in the
    # dataset JSONL and append the absolute path of each clip whose class-0
    # probability is below 0.1 to ``noise_files.txt``.
    # NOTE(review): the meaning of class 0 is defined by the trained
    # classifier, not by this file — confirm before changing the threshold.
    model_path = '/home/hyc/tortoise_plus_zh/ttts/classifier/logs/2023-11-23-17-34-45/model-9.pt'
    config_path = '~/tortoise_plus_zh/ttts/classifier/config.json'
    device = 'cuda'
    classifier = load_model('classifier', model_path, config_path, device)

    jsonl_path = '~/tortoise_plus_zh/ttts/datasets/all_data.jsonl'
    audiopaths_and_text = read_jsonl(jsonl_path)
    audio_paths = [x['path'] for x in audiopaths_and_text]
    ds = MelDataset(audio_paths)
    dl = torch.utils.data.DataLoader(ds, batch_size=1024, num_workers=16)

    output_path = 'ttts/classifier/noise_files.txt'
    # Dataset paths are joined onto the working directory the script was
    # started from; it does not change during the run, so resolve it once.
    cwd = os.getcwd()
    for mels, paths in tqdm(dl, total=len(dl)):
        label = classify_audio_clip(mels.to(device), classifier)
        # Threshold the whole batch in one tensor op instead of comparing
        # element by element, which costs one device-to-host read per clip.
        flagged = (label[:, 0] < 0.1).nonzero(as_tuple=True)[0].tolist()
        if not flagged:
            continue
        # Append once per batch: results are persisted incrementally, and the
        # file is only created when something is actually flagged.
        with open(output_path, 'a') as f:
            f.writelines(os.path.join(cwd, paths[i]) + '\n' for i in flagged)