Spaces:
Runtime error
Runtime error
File size: 2,587 Bytes
e2c1e0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
import logging
from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet
from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet
class VideoEmbedding(torch.nn.Module):
    """Embed video input: a ``Conv3dResNet`` trunk followed by a linear projection.

    :param int idim: input size of the linear layer applied to the trunk output
    :param int odim: output size of the linear layer
    :param float dropout_rate: accepted in the signature but not used here
    :param pos_enc_class: module appended after the linear layer; it is placed
        directly into the ``Sequential``, so presumably an already-built
        positional-encoding instance (confirm against the caller)
    :param str backbone_type: forwarded to ``Conv3dResNet``
    :param str relu_type: forwarded to ``Conv3dResNet``
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc_class, backbone_type="resnet", relu_type="prelu"):
        super().__init__()
        # Front-end feature extractor; also returns the (possibly updated) mask.
        self.trunk = Conv3dResNet(backbone_type=backbone_type, relu_type=relu_type)
        # Projection to the output size, then the supplied positional module.
        projection = torch.nn.Linear(idim, odim)
        self.out = torch.nn.Sequential(projection, pos_enc_class)

    def forward(self, x, x_mask, extract_feats=None):
        """Run the trunk and the output projection on ``x``.

        :param torch.Tensor x: input tensor, passed to the trunk
        :param torch.Tensor x_mask: input mask, passed to the trunk
        :param extract_feats: when truthy, the trunk output is returned as well
        :return: ``(embedded, mask)``, or ``(embedded, mask, trunk_output)``
            when ``extract_feats`` is truthy
        :rtype: Tuple[torch.Tensor, ...]
        """
        features, mask = self.trunk(x, x_mask)
        embedded = self.out(features)
        if not extract_feats:
            return embedded, mask
        # Caller asked for the intermediate trunk features too.
        return embedded, mask, features
class AudioEmbedding(torch.nn.Module):
    """Embed audio input: a ``Conv1dResNet`` trunk followed by a linear projection.

    :param int idim: input size of the linear layer applied to the trunk output
    :param int odim: output size of the linear layer
    :param float dropout_rate: accepted in the signature but not used here
    :param pos_enc_class: module appended after the linear layer; it is placed
        directly into the ``Sequential``, so presumably an already-built
        positional-encoding instance (confirm against the caller)
    :param str relu_type: forwarded to ``Conv1dResNet``
    :param int a_upsample_ratio: forwarded to ``Conv1dResNet``
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc_class, relu_type="prelu", a_upsample_ratio=1):
        super().__init__()
        # Front-end feature extractor; also returns the (possibly updated) mask.
        self.trunk = Conv1dResNet(relu_type=relu_type, a_upsample_ratio=a_upsample_ratio)
        # Projection to the output size, then the supplied positional module.
        projection = torch.nn.Linear(idim, odim)
        self.out = torch.nn.Sequential(projection, pos_enc_class)

    def forward(self, x, x_mask, extract_feats=None):
        """Run the trunk and the output projection on ``x``.

        :param torch.Tensor x: input tensor, passed to the trunk
        :param torch.Tensor x_mask: input mask, passed to the trunk
        :param extract_feats: when truthy, the trunk output is returned as well
        :return: ``(embedded, mask)``, or ``(embedded, mask, trunk_output)``
            when ``extract_feats`` is truthy
        :rtype: Tuple[torch.Tensor, ...]
        """
        features, mask = self.trunk(x, x_mask)
        embedded = self.out(features)
        if not extract_feats:
            return embedded, mask
        # Caller asked for the intermediate trunk features too.
        return embedded, mask, features
|