File size: 2,595 Bytes
4ee33aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from tqdm import tqdm
from ttts.utils.infer_utils import load_model
import json
import torch.nn.functional as F
import torch
import os
def read_jsonl(path):
    """Read a JSON-Lines file and return its records as a list.

    :param path: path to the ``.jsonl`` file; ``~`` is expanded.
    :return: list of parsed JSON objects, one per non-blank line.
    :raises json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    path = os.path.expanduser(path)
    data_list = []
    # Stream line by line instead of slurping the whole file into memory first.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines that would crash json.loads.
                continue
            data_list.append(json.loads(line))
    return data_list
def classify_audio_clip(clip, classifier):
    """Run Tortoise's fake-audio classifier on a clip.

    :param clip: torch tensor containing audio waveform data (get it from load_audio)
    :param classifier: callable model mapping the clip to per-class logits
    :return: softmax probabilities over the classifier's output classes
    """
    with torch.no_grad():
        logits = classifier(clip)
        probabilities = F.softmax(logits, dim=-1)
    return probabilities

class MelDataset(torch.utils.data.Dataset):
    """Dataset yielding fixed-width mel-spectrogram crops plus their source path.

    Each item is loaded from ``<path>.mel.pth``; unreadable files fall back to
    an all-zero mel so a single bad file cannot kill a whole epoch.
    """

    def __init__(self, paths):
        super().__init__()
        self.paths = paths   # audio paths; '<path>.mel.pth' must hold the mel tensor
        self.pad_to = 700    # fixed time-axis width of every returned mel

    def __getitem__(self, index):
        path = self.paths[index]
        try:
            mel = torch.load(path + '.mel.pth')
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; any load failure yields a zero placeholder.
            mel = torch.zeros((1, 100, self.pad_to))
        if mel.shape[-1] >= self.pad_to:
            # Random temporal crop down to exactly pad_to frames.
            start = torch.randint(0, mel.shape[-1] - self.pad_to + 1, (1,))
            mel = mel[:, :, start:start + self.pad_to]
        else:
            # Right-pad short mels with zeros up to pad_to frames.
            padding_needed = self.pad_to - mel.shape[-1]
            mel = F.pad(mel, (0, padding_needed))
        mel = mel.squeeze(0)
        return mel, path

    def __len__(self):
        return len(self.paths)

if __name__ == '__main__':
    # Paths are machine-specific; '~' is expanded by read_jsonl for the jsonl
    # but model/config paths are passed through to load_model as-is.
    model_path = '/home/hyc/tortoise_plus_zh/ttts/classifier/logs/2023-11-23-17-34-45/model-9.pt'
    config_path = '~/tortoise_plus_zh/ttts/classifier/config.json'
    device = 'cuda'
    classifier = load_model('classifier', model_path, config_path, device)
    jsonl_path = '~/tortoise_plus_zh/ttts/datasets/all_data.jsonl'

    audiopaths_and_text = read_jsonl(jsonl_path)
    audio_paths = [x['path'] for x in audiopaths_and_text]
    ds = MelDataset(audio_paths)
    dl = torch.utils.data.DataLoader(ds, batch_size=1024, num_workers=16)
    # Open the report once (append mode preserved) instead of re-opening it
    # for every single flagged clip inside the inner loop.
    with open('ttts/classifier/noise_files.txt', 'a') as out_file:
        for batch in tqdm(dl, total=len(dl)):
            mels, paths = batch
            mels = mels.to(device)
            label = classify_audio_clip(mels, classifier)
            for i in range(label.shape[0]):
                # NOTE(review): assumes class index 0 is the "real/clean"
                # probability, so < 0.1 flags the clip as noise — confirm
                # against the classifier's training labels.
                if label[i][0] < 0.1:
                    out_file.write(os.path.join(os.getcwd(), paths[i]) + '\n')