Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from contextlib import contextmanager | |
| from distutils.version import LooseVersion | |
| from typing import Dict | |
| from typing import Optional | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| # from funasr_detach.layers.abs_normalize import AbsNormalize | |
| # from funasr_detach.models.base_model import FunASRModel | |
| # from funasr_detach.models.encoder.abs_encoder import AbsEncoder | |
| from funasr_detach.frontends.abs_frontend import AbsFrontend | |
| # from funasr_detach.models.preencoder.abs_preencoder import AbsPreEncoder | |
| # from funasr_detach.models.specaug.abs_specaug import AbsSpecAug | |
| from funasr_detach.train_utils.device_funcs import force_gatherable | |
try:
    # torch.cuda.amp (and its autocast context manager) is available from
    # torch 1.6 on; detect the feature directly instead of parsing the
    # version string with the deprecated distutils.LooseVersion.
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: provide a no-op stand-in so callers can
    # still write ``with autocast(False):``. The @contextmanager decorator is
    # required -- a bare generator function is not a context manager and the
    # ``with`` statement would fail.
    @contextmanager
    def autocast(enabled=True):
        yield
class Data2VecPretrainModel(nn.Module):
    """Data2Vec pretraining model.

    Runs the optional input stages (frontend, augmentation, normalization,
    pre-encoder) and then the encoder, whose output dict carries the
    pretraining losses and monitoring statistics.

    Args:
        frontend: Optional callable ``(speech, speech_lengths) ->
            (feats, feats_lengths)``. When ``None`` the input is used as-is.
        specaug: Optional augmentation with the same ``(x, lengths)`` call
            convention; applied only while ``self.training`` is True.
        normalize: Optional feature normalization (same call convention).
        encoder: Encoder module. It must expose ``set_num_updates`` and be
            callable as ``encoder(feats, lengths, mask=..., features_only=...)``.
        preencoder: Optional pre-encoder (same call convention).
    """

    def __init__(
        self,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=None,
        preencoder=None,
    ):
        super().__init__()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # Update counter; pushed to the encoder at the start of every forward().
        self.num_updates = 0

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            ``(loss, stats, weight)`` after passing through
            ``force_gatherable``: the summed encoder losses divided by the
            encoder-reported ``sample_size``; a stats dict with ``loss``,
            ``target_var``, ``pred_var`` and ``ema_decay``; and
            ``sample_size`` as the weight.
        """
        # Check that batch_size is unified
        assert speech.shape[0] == speech_lengths.shape[0], (
            speech.shape,
            speech_lengths.shape,
        )
        # Keep the encoder's update counter in sync with this wrapper's.
        self.encoder.set_num_updates(self.num_updates)

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths)

        # Sum the individual loss terms, then normalize by the number of
        # samples the encoder reports.
        losses = encoder_out["losses"]
        loss = sum(losses.values())
        sample_size = encoder_out["sample_size"]
        loss = loss.sum() / sample_size

        stats = dict(
            loss=torch.clone(loss.detach()),
            target_var=float(encoder_out["target_var"]),
            pred_var=float(encoder_out["pred_var"]),
            ema_decay=float(encoder_out["ema_decay"]),
        )

        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Return the frontend features (no augmentation/normalization)."""
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            The encoder output of
            ``encoder(feats, lengths, mask=True, features_only=False)``.
            ``forward`` reads ``losses``, ``sample_size``, ``target_var``,
            ``pred_var`` and ``ema_decay`` from it.
        """
        # Feature extraction, augmentation and normalization run with
        # autocast disabled (outside mixed precision).
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # For clipped batches (all utterances share one length) pass None as
        # the lengths. Tensor reductions are used instead of the builtin
        # min()/max(), which iterated over the tensor element by element.
        if speech_lengths.min() == speech_lengths.max():
            speech_lengths = None
        # NOTE(review): the encoder receives the original ``speech_lengths``,
        # not ``feats_lengths``; presumably it expects input-level lengths.
        # Confirm if a frontend/pre-encoder that changes the time axis is used.
        encoder_out = self.encoder(
            feats, speech_lengths, mask=True, features_only=False
        )
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Trim batch padding and run the frontend, if any."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel: drop padding beyond the longest utterance in
        # this batch
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def set_num_updates(self, num_updates):
        """Set the update counter forwarded to the encoder."""
        self.num_updates = num_updates

    def get_num_updates(self):
        """Return the current update counter."""
        return self.num_updates
