import numpy as np
import torch
import pandas as pd
from data.filter_data import get_event_list


class Text_Onset_2_Audio_Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, args):
        self.captions = list(dataset[args.text_column])
        self.audios = list(dataset[args.audio_column])
        self.onsets = list(dataset[args.onset_column])
        self.indices = list(range(len(self.captions)))

        # Full index -> [audio_path, caption, onset_str] mapping, built before any truncation.
        self.mapper = {}
        for index, audio, caption, onset in zip(self.indices, self.audios, self.captions, self.onsets):
            self.mapper[index] = [audio, caption, onset]

        # Optionally keep only the first num_examples items (num_examples == -1 keeps everything).
        num_examples = args.num_examples
        if num_examples != -1:
            self.captions = self.captions[:num_examples]
            self.audios = self.audios[:num_examples]
            self.onsets = self.onsets[:num_examples]
            self.indices = self.indices[:num_examples]

        # Map each event label to its row index in the onset grid.
        self.class2id = {event: idx for idx, event in enumerate(args.event_list)}

    def decode_data(self, line_onset_str):
        # Each dataset row looks like:
        #   { "location":  audio_path,
        #     "captions":  "event1 n times and event2 n times",
        #     "onset_str": "event1__onset1-offset1_onset2-offset2--event2__onset1-offset1" }
        # This parses onset_str into a binary (event class x time frame) grid plus an
        # "event1 and event2" style string of the event names.
        line_onset_index = np.zeros((32, 256))  # onset grid: up to 32 event classes x 256 time frames
        line_event = []
        for event_onset in line_onset_str.split('--'):
            # event_onset: "event1__onset1-offset1_onset2-offset2"
            event, instance = event_onset.split('__')
            line_event.append(event)
            # instance: "onset1-offset1_onset2-offset2"
            for start_end in instance.split('_'):
                start, end = start_end.split('-')
                # Convert seconds to frame indices (250 frames per 10 s clip, i.e. 25 frames per second).
                start, end = int(float(start) * 250 / 10), int(float(end) * 250 / 10)
                if end > 255:
                    # Stop processing this event's remaining instances once one runs past the grid.
                    break
                line_onset_index[self.class2id[event], start:end] = 1
        line_event_str = " and ".join(line_event)
        return line_onset_index, line_event_str
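
    # A quick worked example of decode_data (hypothetical event names and times,
    # not taken from the dataset; assumes "dog" and "siren" appear in args.event_list):
    #   decode_data("dog__0.5-2.0_4.0-6.0--siren__1.0-3.0") sets
    #     line_onset_index[class2id["dog"],   12:50]   = 1   # 0.5 s -> frame 12, 2.0 s -> frame 50
    #     line_onset_index[class2id["dog"],   100:150] = 1
    #     line_onset_index[class2id["siren"], 25:75]   = 1
    #   and returns line_event_str == "dog and siren".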
      
    def __len__(self):
        return len(self.captions)

    def get_num_instances(self):
        return len(self.captions)

    def __getitem__(self, index):
        onset_str, filename, idx, caption = self.onsets[index], self.audios[index], self.indices[index], self.captions[index]
        onset, _ = self.decode_data(onset_str)
        # onset_str: "event1__onset1-offset1_onset2-offset2--event2__onset1-offset1"
        # assert len(onset_str.split("--")) == 1
        # Class id of the first event named in onset_str.
        first_class_id = self.class2id[onset_str.split("__")[0]]
        return idx, onset, first_class_id, filename, caption, onset_str

    def collate_fn(self, data):
        # Each sample is (idx, onset, first_class_id, filename, caption, onset_str);
        # the DataFrame columns line up with those positions.
        dat = pd.DataFrame(data)
        batch = []
        for i in dat:
            if i == 1:
                # Column 1: onset grids -> float tensor.
                batch.append(torch.tensor(np.array(dat[i].tolist()), dtype=torch.float32))
            elif i == 2:
                # Column 2: first-event class ids -> integer tensor.
                batch.append(torch.tensor(dat[i]))
            else:
                # Remaining columns (indices, filenames, captions, onset strings) stay as lists.
                batch.append(dat[i].tolist())
        return batch
    
class Clap_Onset_2_Audio_Dataset(Text_Onset_2_Audio_Dataset):
    """Variant that replaces the first-event class id with a CLAP text embedding of the event string."""

    def __init__(self, dataset, args):
        super().__init__(dataset, args)
        # Lazy imports so the plain text/onset dataset does not require laion_clap.
        import laion_clap
        from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict

        self.clap_scorer = laion_clap.CLAP_Module(enable_fusion=False)
        # Hard-coded path to a locally downloaded CLAP checkpoint; adjust to your environment.
        ckpt_path = 'miniconda3/envs/py3.10.11/lib/python3.10/site-packages/laion_clap/630k-audioset-best.pt'
        ckpt = clap_load_state_dict(ckpt_path, skip_params=True)
        # Drop the incompatible position_ids buffer and prefix keys with "model."
        # so they match CLAP_Module's parameter names.
        del_parameter_key = ["text_branch.embeddings.position_ids"]
        ckpt = {"model." + k: v for k, v in ckpt.items() if k not in del_parameter_key}
        self.clap_scorer.load_state_dict(ckpt)

    def __getitem__(self, index):
        onset_str, filename, idx, caption = self.onsets[index], self.audios[index], self.indices[index], self.captions[index]
        onset, event = self.decode_data(onset_str)
        with torch.no_grad():
            # Embed the "event1 and event2" string; the trailing "" is a second entry
            # whose embedding is discarded, since [0] keeps only the first.
            clap_embed = self.clap_scorer.get_text_embedding([event, ""], use_tensor=False)[0]
        return idx, onset, clap_embed, filename, caption, onset_str

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        batch = []
        for i in dat:
            if i == 1 or i == 2:
                # Columns 1 and 2: onset grids and CLAP text embeddings -> float tensors.
                batch.append(torch.tensor(np.array(dat[i].tolist()), dtype=torch.float32))
            else:
                batch.append(dat[i].tolist())
        return batch


if __name__ == "__main__":
    # Minimal smoke test for the datasets defined above.
    import argparse

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    from data_utils.filter_data import get_event_list

    parser = argparse.ArgumentParser(description=".")
    args = parser.parse_args()
    args.event_list = get_event_list()
    args.train_file = ""  # path to the training manifest (csv/json), left unset here

    # Column names and defaults below are placeholders matching the schema
    # described in decode_data; adjust them to the actual dataset.
    args.text_column = "captions"
    args.audio_column = "location"
    args.onset_column = "onset_str"
    args.num_examples = -1
    args.batch_size = 4

    extension = args.train_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files={"train": args.train_file})
    train_dataset = Clap_Onset_2_Audio_Dataset(raw_datasets["train"], args)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                                  collate_fn=train_dataset.collate_fn)
    for batch in train_dataloader:
        import pdb; pdb.set_trace()  # drop into the debugger to inspect one batch
        idx, onset, event_info, audios, caption, onset_str = batch