# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Union

import torch
import torch.nn.functional as F

from fairseq.models.roberta.hub_interface import RobertaHubInterface
class XMODHubInterface(RobertaHubInterface):
    """Hub interface for X-MOD models.

    Extends :class:`RobertaHubInterface` so that feature extraction and
    classification forward a ``lang_id`` to the underlying model, which
    selects the language-specific adapter.
    """

    def extract_features(
        self,
        tokens: torch.LongTensor,
        return_all_hiddens: bool = False,
        lang_id=None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Extract features for a batch of token ids.

        Args:
            tokens: token ids, 1-D ``(seq_len,)`` or 2-D ``(batch, seq_len)``;
                a 1-D input is promoted to a batch of one.
            return_all_hiddens: if True, return the hidden states of every
                layer instead of only the last layer's features.
            lang_id: language identifier passed through to the model
                (semantics defined by the model's adapter selection).

        Returns:
            The last layer's features as a single tensor, or — when
            ``return_all_hiddens`` is True — a list with one batch-first
            tensor per layer.

        Raises:
            ValueError: if the sequence is longer than the model's maximum
                supported positions.
        """
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        # Hoisted: max_positions() is needed both for the check and the message.
        max_positions = self.model.max_positions()
        if tokens.size(-1) > max_positions:
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    tokens.size(-1), max_positions
                )
            )
        features, extra = self.model(
            tokens.to(device=self.device),
            features_only=True,
            return_all_hiddens=return_all_hiddens,
            lang_id=lang_id,
        )
        if return_all_hiddens:
            # convert from T x B x C -> B x T x C
            inner_states = extra["inner_states"]
            return [inner_state.transpose(0, 1) for inner_state in inner_states]
        return features  # just the last layer's features

    def predict(
        self,
        head: str,
        tokens: torch.LongTensor,
        return_logits: bool = False,
        lang_id=None,
    ) -> torch.Tensor:
        """Apply a registered classification head to the extracted features.

        Args:
            head: name of a head in ``self.model.classification_heads``.
            tokens: token ids (1-D or 2-D, as for :meth:`extract_features`).
            return_logits: if True, return raw logits; otherwise return
                log-probabilities over the head's classes.
            lang_id: language identifier passed through to the model.

        Returns:
            Logits or log-softmax scores produced by the head.
        """
        features = self.extract_features(
            tokens.to(device=self.device), lang_id=lang_id
        )
        logits = self.model.classification_heads[head](features)
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)