File size: 566 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from torch import Tensor
from torch import nn
from typing import Dict
import torch.nn.functional as F
class Normalize(nn.Module):
    """
    This layer normalizes embeddings to unit length.

    The layer holds no parameters or configuration of its own: ``__init__``
    sets no attributes, ``save`` writes nothing, and ``load`` simply builds
    a fresh instance.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        L2-normalize the ``'sentence_embedding'`` entry of ``features``.

        The dictionary is updated in place and the same object is returned;
        all other entries are left untouched.

        :param features: mapping that must contain a ``'sentence_embedding'``
            tensor. Normalization runs along dimension 1, so the tensor needs
            at least two dimensions (presumably ``(batch, embedding_dim)`` --
            confirm against the caller).
        :return: the same ``features`` dictionary, with the embedding replaced
            by its normalized version.
        :raises KeyError: if ``'sentence_embedding'`` is missing.
        """
        # p=2 selects the Euclidean norm; dim=1 normalizes each row separately.
        features.update({'sentence_embedding': F.normalize(features['sentence_embedding'], p=2, dim=1)})
        return features

    def save(self, output_path) -> None:
        """
        Do nothing: there is no state to write.

        :param output_path: ignored.
        """
        pass

    @staticmethod
    def load(input_path) -> "Normalize":
        """
        Return a new ``Normalize`` layer.

        :param input_path: ignored, since nothing is read from disk.
        :return: a freshly constructed ``Normalize`` instance.
        """
        return Normalize()
|