|  |  | 
					
						
						|  | from torch import nn | 
					
						
						|  | from torch.autograd import Function | 
					
						
						|  | from torch.nn.modules.utils import _pair | 
					
						
						|  |  | 
					
						
						|  | from ..utils import ext_loader | 
					
						
						|  |  | 
					
						
						|  | ext_module = ext_loader.load_ext('_ext', | 
					
						
						|  | ['psamask_forward', 'psamask_backward']) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class PSAMaskFunction(Function): | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def symbolic(g, input, psa_type, mask_size): | 
					
						
						|  | return g.op( | 
					
						
						|  | 'mmcv::MMCVPSAMask', | 
					
						
						|  | input, | 
					
						
						|  | psa_type_i=psa_type, | 
					
						
						|  | mask_size_i=mask_size) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def forward(ctx, input, psa_type, mask_size): | 
					
						
						|  | ctx.psa_type = psa_type | 
					
						
						|  | ctx.mask_size = _pair(mask_size) | 
					
						
						|  | ctx.save_for_backward(input) | 
					
						
						|  |  | 
					
						
						|  | h_mask, w_mask = ctx.mask_size | 
					
						
						|  | batch_size, channels, h_feature, w_feature = input.size() | 
					
						
						|  | assert channels == h_mask * w_mask | 
					
						
						|  | output = input.new_zeros( | 
					
						
						|  | (batch_size, h_feature * w_feature, h_feature, w_feature)) | 
					
						
						|  |  | 
					
						
						|  | ext_module.psamask_forward( | 
					
						
						|  | input, | 
					
						
						|  | output, | 
					
						
						|  | psa_type=psa_type, | 
					
						
						|  | num_=batch_size, | 
					
						
						|  | h_feature=h_feature, | 
					
						
						|  | w_feature=w_feature, | 
					
						
						|  | h_mask=h_mask, | 
					
						
						|  | w_mask=w_mask, | 
					
						
						|  | half_h_mask=(h_mask - 1) // 2, | 
					
						
						|  | half_w_mask=(w_mask - 1) // 2) | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def backward(ctx, grad_output): | 
					
						
						|  | input = ctx.saved_tensors[0] | 
					
						
						|  | psa_type = ctx.psa_type | 
					
						
						|  | h_mask, w_mask = ctx.mask_size | 
					
						
						|  | batch_size, channels, h_feature, w_feature = input.size() | 
					
						
						|  | grad_input = grad_output.new_zeros( | 
					
						
						|  | (batch_size, channels, h_feature, w_feature)) | 
					
						
						|  | ext_module.psamask_backward( | 
					
						
						|  | grad_output, | 
					
						
						|  | grad_input, | 
					
						
						|  | psa_type=psa_type, | 
					
						
						|  | num_=batch_size, | 
					
						
						|  | h_feature=h_feature, | 
					
						
						|  | w_feature=w_feature, | 
					
						
						|  | h_mask=h_mask, | 
					
						
						|  | w_mask=w_mask, | 
					
						
						|  | half_h_mask=(h_mask - 1) // 2, | 
					
						
						|  | half_w_mask=(w_mask - 1) // 2) | 
					
						
						|  | return grad_input, None, None, None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | psa_mask = PSAMaskFunction.apply | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class PSAMask(nn.Module): | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, psa_type, mask_size=None): | 
					
						
						|  | super(PSAMask, self).__init__() | 
					
						
						|  | assert psa_type in ['collect', 'distribute'] | 
					
						
						|  | if psa_type == 'collect': | 
					
						
						|  | psa_type_enum = 0 | 
					
						
						|  | else: | 
					
						
						|  | psa_type_enum = 1 | 
					
						
						|  | self.psa_type_enum = psa_type_enum | 
					
						
						|  | self.mask_size = mask_size | 
					
						
						|  | self.psa_type = psa_type | 
					
						
						|  |  | 
					
						
						|  | def forward(self, input): | 
					
						
						|  | return psa_mask(input, self.psa_type_enum, self.mask_size) | 
					
						
						|  |  | 
					
						
						|  | def __repr__(self): | 
					
						
						|  | s = self.__class__.__name__ | 
					
						
						|  | s += f'(psa_type={self.psa_type}, ' | 
					
						
						|  | s += f'mask_size={self.mask_size})' | 
					
						
						|  | return s | 
					
						
						|  |  |