Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from . import SparseTensor | |
# Public names exported by a star-import of this module.
__all__ = ['SparseLinear']
class SparseLinear(nn.Linear):
    """``nn.Linear`` variant whose ``forward`` works on a ``SparseTensor``.

    Construction is delegated unchanged to ``torch.nn.Linear``. Only the
    forward pass differs: the linear map is applied to ``input.feats`` and the
    result is passed to ``input.replace``.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Initialise the underlying ``nn.Linear``.

        Args:
            in_features: Forwarded to ``nn.Linear`` as ``in_features``.
            out_features: Forwarded to ``nn.Linear`` as ``out_features``.
            bias: Forwarded to ``nn.Linear`` as ``bias``.
        """
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply the linear layer to the features carried by ``input``.

        Args:
            input: Object exposing ``feats`` and ``replace``.

        Returns:
            Whatever ``input.replace`` returns for the transformed features.
            NOTE(review): presumably a new ``SparseTensor`` that keeps the
            sparse structure of ``input`` -- confirm against
            ``SparseTensor.replace``.
        """
        projected = super().forward(input.feats)
        return input.replace(projected)