# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py

import torch
import torch.nn as nn
import numpy as np
from pesq import pesq
from joblib import Parallel, delayed

def phase_losses(phase_r, phase_g, cfg):
    """
    Calculate phase losses, including instantaneous phase loss, group delay loss,
    and instantaneous angular frequency loss, between reference and generated phases.

    Args:
        phase_r (torch.Tensor): Reference phase tensor of shape (batch, freq, time).
        phase_g (torch.Tensor): Generated phase tensor of shape (batch, freq, time).
        cfg (dict): Configuration dictionary containing parameters like n_fft.

    Returns:
        tuple: Tuple containing instantaneous phase loss, group delay loss, and
        instantaneous angular frequency loss.
    """
    dim_freq = cfg['stft_cfg']['n_fft'] // 2 + 1  # Frequency dimension
    dim_time = phase_r.size(-1)                   # Time dimension

    # Construct the group delay matrix (finite difference along the frequency axis)
    gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) -
                 torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) -
                 torch.eye(dim_freq)).to(phase_g.device)

    # Apply the group delay matrix to reference and generated phases
    gd_r = torch.matmul(phase_r.permute(0, 2, 1), gd_matrix)
    gd_g = torch.matmul(phase_g.permute(0, 2, 1), gd_matrix)

    # Construct the instantaneous angular frequency matrix (finite difference along the time axis)
    iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) -
                  torch.triu(torch.ones(dim_time, dim_time), diagonal=2) -
                  torch.eye(dim_time)).to(phase_g.device)

    # Apply the instantaneous angular frequency matrix to reference and generated phases
    iaf_r = torch.matmul(phase_r, iaf_matrix)
    iaf_g = torch.matmul(phase_g, iaf_matrix)

    # Anti-wrapped L1 distances between reference and generated phase representations
    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
    gd_loss = torch.mean(anti_wrapping_function(gd_r - gd_g))
    iaf_loss = torch.mean(anti_wrapping_function(iaf_r - iaf_g))

    return ip_loss, gd_loss, iaf_loss

def anti_wrapping_function(x):
    """
    Anti-wrapping function that maps phase differences to their principal values
    in [-pi, pi] and returns the absolute value.

    Args:
        x (torch.Tensor): Input tensor representing phase differences.

    Returns:
        torch.Tensor: Absolute value of the phase differences wrapped into [-pi, pi].
    """
    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)

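
# Illustrative sketch (not part of the reference implementation): a minimal check of
# anti_wrapping_function and phase_losses on random phases, using an assumed config
# dict with the same 'stft_cfg'/'n_fft' layout that phase_losses reads.
def _demo_phase_losses():
    cfg = {'stft_cfg': {'n_fft': 400}}                     # hypothetical config
    phase_r = torch.rand(2, 201, 50) * 2 * np.pi - np.pi   # (batch, freq, time), freq = n_fft // 2 + 1
    phase_g = torch.rand(2, 201, 50) * 2 * np.pi - np.pi
    # A difference of 2*pi - 0.1 wraps to 0.1: nearly identical phases yield a small penalty.
    print(anti_wrapping_function(torch.tensor(2 * np.pi - 0.1)))  # ~0.1
    ip_loss, gd_loss, iaf_loss = phase_losses(phase_r, phase_g, cfg)
    print(ip_loss.item(), gd_loss.item(), iaf_loss.item())
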
def compute_stft(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int, center: bool, compress_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the Short-Time Fourier Transform (STFT) and return magnitude, phase, and complex components.

    Args:
        y (torch.Tensor): Input signal tensor.
        n_fft (int): Number of FFT points.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        center (bool): Whether to pad the input on both sides.
        compress_factor (float, optional): Compression factor for the magnitude. Defaults to 1.0.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Magnitude, phase, and complex components.
    """
    eps = torch.finfo(y.dtype).eps
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(
        y,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode='reflect',
        normalized=False,
        return_complex=True
    )
    real_part = stft_spec.real
    imag_part = stft_spec.imag
    # Magnitude is the Euclidean norm of the real and imaginary parts (eps avoids sqrt(0))
    mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps)
    # Phase is atan2(imag, real); eps avoids the undefined atan2(0, 0) case
    pha = torch.atan2(imag_part + eps, real_part + eps)
    # Optional power-law compression of the magnitude
    mag = torch.pow(mag, compress_factor)
    # Compressed complex spectrogram reassembled from magnitude and phase
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com

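
# Illustrative sketch (not part of the reference implementation): compute the STFT of a
# random 1-second, 16 kHz signal with assumed MP-SENet-style parameters
# (n_fft=400, hop_size=100, win_size=400, compress_factor=0.3).
def _demo_compute_stft():
    y = torch.randn(1, 16000)  # (batch, samples), hypothetical input
    mag, pha, com = compute_stft(y, n_fft=400, hop_size=100, win_size=400,
                                 center=True, compress_factor=0.3)
    print(mag.shape, pha.shape, com.shape)  # (1, 201, 161), (1, 201, 161), (1, 201, 161, 2)
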
def pesq_score(utts_r, utts_g, cfg):
    """
    Calculate the PESQ (Perceptual Evaluation of Speech Quality) score for pairs of
    reference and generated utterances.

    Args:
        utts_r (list of torch.Tensor): List of reference utterances.
        utts_g (list of torch.Tensor): List of generated utterances.
        cfg (dict): Configuration dictionary containing parameters like sampling_rate.

    Returns:
        float: Mean PESQ score across all pairs of utterances.
    """
    def eval_pesq(clean_utt, esti_utt, sr):
        """
        Evaluate the PESQ score for a single pair of clean and estimated utterances.

        Args:
            clean_utt (np.ndarray): Clean reference utterance.
            esti_utt (np.ndarray): Estimated (generated) utterance.
            sr (int): Sampling rate.

        Returns:
            float: PESQ score, or -1 in case of an error.
        """
        try:
            # Uses the pesq package's default wideband ('wb') mode
            pesq_score = pesq(sr, clean_utt, esti_utt)
        except Exception as e:
            # Errors can occur for silent periods or other degenerate inputs;
            # failed pairs contribute -1 to the mean below
            print(f"Error computing PESQ score: {e}")
            pesq_score = -1
        return pesq_score

    # Compute PESQ scores for all utterance pairs in parallel
    pesq_scores = Parallel(n_jobs=30)(delayed(eval_pesq)(
        utts_r[i].squeeze().cpu().numpy(),
        utts_g[i].squeeze().cpu().numpy(),
        cfg['stft_cfg']['sampling_rate']
    ) for i in range(len(utts_r)))

    # Mean PESQ score across all pairs
    pesq_score = np.mean(pesq_scores)

    return pesq_score
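
# Illustrative sketch (not part of the reference implementation): exercise pesq_score
# with random 16 kHz signals and an assumed config dict using the same
# 'stft_cfg'/'sampling_rate' layout that pesq_score reads. Real evaluation would pass
# speech; random signals may trip PESQ's utterance detection, in which case the
# error handler above maps those pairs to -1.
def _demo_pesq_score():
    cfg = {'stft_cfg': {'sampling_rate': 16000}}              # hypothetical config
    utts_r = [torch.randn(1, 16000) for _ in range(2)]        # "reference" signals
    utts_g = [u + 0.1 * torch.randn_like(u) for u in utts_r]  # lightly perturbed copies
    print(pesq_score(utts_r, utts_g, cfg))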