import os

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader

from hparams import Hparams
from model_cnn import Model
from dataset import MyDataset

args = Hparams.args
device = args['device']
split = 'train'
tone_class = 5
NUM_EPOCHS = 100


def move_data_to_device(data, device):
    """Move every tensor in `data` to the target device."""
    ret = []
    for i in data:
        if isinstance(i, torch.Tensor):
            ret.append(i.to(device))
    return ret


def collate_fn(batch):
    """Pad each field of the batch to a common length and stack into tensors."""
    inp, f0, word, tone = [], [], [], []
    max_frame_num = 1600
    for sample in batch:
        max_frame_num = max(max_frame_num, sample[0].shape[0], sample[1].shape[0],
                            sample[2].shape[0], sample[3].shape[0])
    for sample in batch:
        # Mel spectrogram: pad along the frame axis (F.pad pads from the last
        # dimension backwards, so (0, 0, 0, n) pads dim 0).
        inp.append(torch.nn.functional.pad(
            sample[0], (0, 0, 0, max_frame_num - sample[0].shape[0]),
            mode='constant', value=0))
        # F0 contour: 1-D, padded to max_frame_num frames.
        f0.append(torch.nn.functional.pad(
            sample[1], (0, max_frame_num - sample[1].shape[0]),
            mode='constant', value=0))
        # Syllable and tone label sequences: assumed to hold at most 50 tokens,
        # padded to that fixed length with 0 (the blank/padding label).
        word.append(torch.nn.functional.pad(
            sample[2], (0, 50 - sample[2].shape[0]), mode='constant', value=0))
        tone.append(torch.nn.functional.pad(
            sample[3], (0, 50 - sample[3].shape[0]), mode='constant', value=0))
    return torch.stack(inp), torch.stack(f0), torch.stack(word), torch.stack(tone)


def get_data_loader(split, args):
    dataset = MyDataset(
        dataset_root=args['dataset_root'],
        split=split,
        sampling_rate=args['sampling_rate'],
        sample_length=args['sample_length'],
        frame_size=args['frame_size'],
    )
    # Debugging leftover: restrict the dataset to its first 32 samples.
    dataset.dataset_index = dataset.dataset_index[:32]
    dataset.index = dataset.index[:32]
    data_loader = DataLoader(
        dataset,
        batch_size=args['batch_size'],
        num_workers=args['num_workers'],
        pin_memory=True,
        # Shuffle because audio files recorded by the same speaker are stored
        # in the same folder.
        shuffle=True,
        collate_fn=collate_fn,
    )
    return data_loader


def process_sequence(seq):
    """Collapse consecutive duplicate labels (the merge step of greedy CTC decoding)."""
    ret = []
    for w in seq:
        if len(ret) == 0 or ret[-1] != w:
            ret.append(w)
    return ret
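# A minimal illustration of process_sequence (the label indices are invented
# for the example): merging consecutive repeats is the first half of greedy
# CTC decoding, and ASR_Model.predict afterwards drops the blank label 0.
#
#     process_sequence([0, 0, 7, 7, 7, 0, 7, 12, 12, 0])
#     # -> [0, 7, 0, 7, 12, 0]; stripping blanks then yields [7, 7, 12]
#
# The blank between the two runs of 7 is what lets a genuinely repeated
# label survive the merge.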
class ASR_Model:
    """Main class for training the model and making predictions."""

    def __init__(self, device='cpu', model_path=None, pinyin_path='pinyin.txt'):
        self.device = device
        # Read the pinyin vocabulary: one syllable per line, the line number
        # becomes its integer id.
        self.pinyin = {}
        with open(pinyin_path, 'r') as f:
            for i, line in enumerate(f):
                self.pinyin[line.replace('\n', '')] = i
        self.idx2char = {idx: char for char, idx in self.pinyin.items()}
        # Hard-coded; originally len(pinyin) * tone_class + 1, where the extra
        # class is the CTC blank (index 0).
        num_class = 2036
        self.model = Model(syllable_class=num_class).to(self.device)
        self.sampling_rate = args['sampling_rate']
        if model_path is not None:
            self.model = torch.load(model_path)
            print('Model loaded.')
        else:
            print('Model initialized.')
        self.model.to(device)

    def fit(self, args, NUM_EPOCHS=100):
        # Set paths
        save_model_dir = args['save_model_dir']
        if not os.path.exists(save_model_dir):
            os.mkdir(save_model_dir)

        # CTC loss: the model output is assumed to be frame-wise log-probabilities.
        loss_fn = nn.CTCLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        train_loader = get_data_loader(split='train', args=args)
        # NOTE: validation currently reuses the training split.
        valid_loader = get_data_loader(split='train', args=args)

        # Start training
        print('Start training...')
        min_valid_loss = float('inf')
        for epoch in range(NUM_EPOCHS):
            self.model.train()  # re-enter train mode after each validation pass
            for idx, data in enumerate(train_loader):
                mel, f0, word, tone = move_data_to_device(data, self.device)
                # Number of non-padded frames per sample.
                input_length = (mel[:, :, 0] != 0.0).sum(dim=1)
                mel = mel.unsqueeze(1)  # add a channel dim: (B, 1, T, n_mels)
                output = self.model(mel)
                output = output.permute(1, 0, 2)  # CTCLoss expects (T, B, C)
                # The CNN downsamples time by a factor of 4.
                output_len = (input_length // 4).to(self.device)
                # Combine syllable and tone ids into a single class id.
                target_len = (tone != 0).sum(dim=1)
                target = word * 5 + tone
                loss = loss_fn(output, target, output_len, target_len)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if idx % 100 == 0:
                    print(f'Epoch {epoch+1}, Iteration {idx+1}, Loss: {loss.item()}')

            # Validation
            self.model.eval()
            with torch.no_grad():
                losses = []
                for idx, data in enumerate(valid_loader):
                    mel, f0, word, tone = move_data_to_device(data, self.device)
                    input_length = (mel[:, :, 0] != 0.0).sum(dim=1)
                    mel = mel.unsqueeze(1)
                    out = self.model(mel)
                    out = out.permute(1, 0, 2)
                    output_len = (input_length // 4).to(self.device)
                    target_len = (tone != 0).sum(dim=1)
                    target = word * 5 + tone
                    loss = loss_fn(out, target, output_len, target_len)
                    losses.append(loss.item())
                valid_loss = np.mean(losses)

            # Save the best model
            if valid_loss < min_valid_loss:
                min_valid_loss = valid_loss
                target_model_path = os.path.join(save_model_dir, 'best_model.pth')
                torch.save(self.model, target_model_path)

    def to_pinyin(self, num):
        """Decode a combined class id back into a (pinyin, tone) pair."""
        if num == 0:  # CTC blank
            return None
        pinyin, tone = self.idx2char[(num - 1) // 5], (num - 1) % 5 + 1
        return pinyin, tone

    def getsentence(self, words):
        """Map a sequence of syllable ids back to pinyin strings."""
        words = words.tolist()
        return [self.idx2char[int(word)] for word in words]

    def predict(self, audio_fp):
        """Predict the (pinyin, tone) sequence for a single audio file."""
        waveform, sample_rate = torchaudio.load(audio_fp)
        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
        mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sampling_rate, n_fft=2048,
            hop_length=100, n_mels=256)(waveform)
        mel_spec = torch.mean(mel_spec, 0)  # average the channels to mono
        # F0 is extracted for parity with the training features but is unused here.
        waveform, sr = librosa.load(audio_fp, sr=self.sampling_rate)
        f0 = torch.from_numpy(librosa.yin(waveform, fmin=50, fmax=550, hop_length=100))
        mel = mel_spec.T.unsqueeze(0).unsqueeze(0)  # (1, 1, T, n_mels)
        self.model.eval()
        with torch.no_grad():
            output = self.model(mel.to(self.device))
        # Greedy CTC decode: argmax per frame, merge repeats, drop blanks.
        seq = process_sequence(output[0].cpu().numpy().argmax(-1))
        result = [self.to_pinyin(c) for c in seq if c != 0]
        return result
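# A minimal usage sketch: 'demo.wav' is a hypothetical input file, and
# Hparams.args is assumed to provide the keys read by fit() and
# get_data_loader() ('save_model_dir', 'dataset_root', 'batch_size', ...).
if __name__ == '__main__':
    asr = ASR_Model(device=args['device'], pinyin_path='pinyin.txt')
    asr.fit(args, NUM_EPOCHS=10)

    # Reload the best checkpoint and transcribe a single file.
    best = ASR_Model(device=args['device'],
                     model_path=os.path.join(args['save_model_dir'], 'best_model.pth'))
    print(best.predict('demo.wav'))  # e.g. [('ni', 3), ('hao', 3)]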