|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from typing import Literal, Optional |
|
|
|
import logging |
|
# Root logging is configured at import time: WARNING and above, with a
# timestamp / logger-name / level prefix on every record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.WARNING,
)

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
class CustomSoftMax(nn.Module):
    """Temperature-scaled softmax or Gumbel-softmax applied along dimension 1.

    The temperature is stored as a one-element ``nn.Parameter``. A learnable
    temperature is rejected at construction time, so the parameter is always
    created with ``requires_grad=False``.

    Args:
        sfx_type: ``'softmax'`` for ``softmax(x / temperature)`` or
            ``'gumbel_softmax'`` for ``F.gumbel_softmax`` with
            ``tau=temperature``.
        temperature: Temperature value; converted with ``float()``.
        is_temperature_learnable: Must be falsy in this version.
        is_gumbel_hard: Forwarded as ``hard=`` to ``F.gumbel_softmax``.
            Required (not ``None``) when ``sfx_type`` is ``'gumbel_softmax'``;
            this is checked when ``forward`` runs, not at construction.
        *args: Kept on the instance as ``self.args``; not used by this class.
        **kwargs: Kept on the instance as ``self.kwargs``; not used by this
            class.

    Raises:
        ValueError: If ``is_temperature_learnable`` is truthy.
    """

    def __init__(
        self,
        sfx_type: Literal['gumbel_softmax', 'softmax'],
        temperature: float,
        is_temperature_learnable: bool,
        is_gumbel_hard: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.sfx_type = sfx_type

        # Raised explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let a trainable temperature through.
        if is_temperature_learnable:
            raise ValueError(
                'is_temperature_learnable is prohibited in this version, will go to negative'
            )

        self.temperature = nn.Parameter(
            torch.tensor([float(temperature)]),
            requires_grad=is_temperature_learnable,
        )
        self.is_gumbel_hard = is_gumbel_hard
        self.args = args
        self.kwargs = kwargs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along dimension 1 using the configured variant.

        Args:
            x: Input logits; normalisation runs over ``dim=1``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``sfx_type`` is ``'gumbel_softmax'`` but
                ``is_gumbel_hard`` was left as ``None``.
            NotImplementedError: If ``sfx_type`` is not a supported variant.
        """
        if self.sfx_type == 'softmax':
            return F.softmax(x / self.temperature, dim=1)

        if self.sfx_type == 'gumbel_softmax':
            if self.is_gumbel_hard is None:
                raise ValueError('is_gumbel_hard is not passed')
            return F.gumbel_softmax(
                x, tau=self.temperature, hard=self.is_gumbel_hard, dim=1
            )

        raise NotImplementedError(f'{self.sfx_type} is not implemented yet')
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build each configuration, feed it a random
    # (10, 3) batch and print the input shape and the output.
    #
    # Two of these configurations used to pass is_temperature_learnable=True.
    # The constructor rejects that flag, so the script aborted on the second
    # case and never reached the softmax cases. The flag is now False
    # everywhere so every case actually runs.
    demo_configs = (
        dict(sfx_type='gumbel_softmax', temperature=1, is_temperature_learnable=False, is_gumbel_hard=True),
        dict(sfx_type='gumbel_softmax', temperature=1, is_temperature_learnable=False, is_gumbel_hard=True),
        dict(sfx_type='softmax', temperature=1, is_temperature_learnable=False),
        dict(sfx_type='softmax', temperature=0.01, is_temperature_learnable=False, is_gumbel_hard=None),
    )

    for config in demo_configs:
        sfx = CustomSoftMax(**config)
        x = torch.randn(10, 3)
        print(x.shape)
        print(sfx(x))
|
|