from typing import Optional

import torch
from torch_complex.tensor import ComplexTensor

from espnet2.enh.encoder.abs_encoder import AbsEncoder
from espnet2.layers.stft import Stft
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class STFTEncoder(AbsEncoder): | 
					
						
						|  | """STFT encoder for speech enhancement and separation """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | n_fft: int = 512, | 
					
						
						|  | win_length: int = None, | 
					
						
						|  | hop_length: int = 128, | 
					
						
						|  | window="hann", | 
					
						
						|  | center: bool = True, | 
					
						
						|  | normalized: bool = False, | 
					
						
						|  | onesided: bool = True, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.stft = Stft( | 
					
						
						|  | n_fft=n_fft, | 
					
						
						|  | win_length=win_length, | 
					
						
						|  | hop_length=hop_length, | 
					
						
						|  | window=window, | 
					
						
						|  | center=center, | 
					
						
						|  | normalized=normalized, | 
					
						
						|  | onesided=onesided, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self._output_dim = n_fft // 2 + 1 if onesided else n_fft | 
					
						
						|  |  | 
					
						
						|  | @property | 
					
						
						|  | def output_dim(self) -> int: | 
					
						
						|  | return self._output_dim | 
					
						
						|  |  | 
					
						
						|  | def forward(self, input: torch.Tensor, ilens: torch.Tensor): | 
					
						
						|  | """Forward. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | input (torch.Tensor): mixed speech [Batch, sample] | 
					
						
						|  | ilens (torch.Tensor): input lengths [Batch] | 
					
						
						|  | Returns: | 
					
						
						|  | stft spectrum (torch.ComplexTensor):  (Batch, Frames, Freq) | 
					
						
						|  | or (Batch, Frames, Channels, Freq) | 
					
						
						|  | """ | 
					
						
						|  | spectrum, flens = self.stft(input, ilens) | 
					
						
						|  | spectrum = ComplexTensor(spectrum[..., 0], spectrum[..., 1]) | 
					
						
						|  |  | 
					
						
						|  | return spectrum, flens | 
					
						
						|  |  |