Spaces:
Build error
Build error
| # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) | |
| # | |
| # Copied from https://github.com/k2-fsa/sherpa/blob/master/sherpa/bin/conformer_rnnt/decode.py | |
| # | |
| # See LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from typing import List | |
| import torch | |
| from sherpa import RnntConformerModel, greedy_search, modified_beam_search | |
| from torch.nn.utils.rnn import pad_sequence | |
# Probability whose natural log is used as the padding value when the
# per-utterance feature matrices are batched with pad_sequence() below.
# NOTE(review): presumably the features are log-domain, so padded frames
# get a very small (log) value -- confirm against the feature extractor.
_LOG_EPS_PROB = 1e-10
LOG_EPS = math.log(_LOG_EPS_PROB)
def run_model_and_do_greedy_search(
    model: RnntConformerModel,
    features: List[torch.Tensor],
) -> List[List[int]]:
    """Run the RNN-T model on the given features and decode with greedy search.

    The utterances are padded into one batch, run through ``model.encoder``
    on ``model.device``, and the encoder output is decoded by
    ``greedy_search``.

    Args:
      model:
        The RNN-T model.
      features:
        A list of 2-D tensors. Each entry is of shape
        (num_frames, feature_dim).
    Returns:
      Return a list-of-list containing the decoding token IDs (presumably
      one inner list per entry of ``features``, in order). An empty list is
      returned when ``features`` is empty.
    """
    # pad_sequence() cannot handle an empty list, so an empty batch is
    # answered directly instead of crashing.
    if not features:
        return []

    # Record the true lengths before padding; after padding every entry is
    # as long as the longest one.
    features_length = torch.tensor(
        [f.size(0) for f in features],
        dtype=torch.int64,
    )
    # Shape (batch, max_num_frames, feature_dim). Padded frames are filled
    # with LOG_EPS rather than zeros.
    padded = pad_sequence(
        features,
        batch_first=True,
        padding_value=LOG_EPS,
    )

    device = model.device
    encoder_out, encoder_out_length = model.encoder(
        features=padded.to(device),
        features_length=features_length.to(device),
    )

    # The search receives the encoder output lengths on the CPU.
    hyp_tokens = greedy_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_length=encoder_out_length.cpu(),
    )
    return hyp_tokens
def run_model_and_do_modified_beam_search(
    model: RnntConformerModel,
    features: List[torch.Tensor],
    num_active_paths: int,
) -> List[List[int]]:
    """Run the RNN-T model on the given features and decode with modified
    beam search.

    The utterances are padded into one batch, run through ``model.encoder``
    on ``model.device``, and the encoder output is decoded by
    ``modified_beam_search``.

    Args:
      model:
        The RNN-T model.
      features:
        A list of 2-D tensors. Each entry is of shape
        (num_frames, feature_dim).
      num_active_paths:
        Number of active paths kept for each utterance during modified
        beam search. Because paths with identical token sequences are
        merged, the actual number may be less than "num_active_paths".
    Returns:
      Return a list-of-list containing the decoding token IDs (presumably
      one inner list per entry of ``features``, in order). An empty list is
      returned when ``features`` is empty.
    """
    # pad_sequence() cannot handle an empty list, so an empty batch is
    # answered directly instead of crashing.
    if not features:
        return []

    # Record the true lengths before padding; after padding every entry is
    # as long as the longest one.
    features_length = torch.tensor(
        [f.size(0) for f in features],
        dtype=torch.int64,
    )
    # Shape (batch, max_num_frames, feature_dim). Padded frames are filled
    # with LOG_EPS rather than zeros.
    padded = pad_sequence(
        features,
        batch_first=True,
        padding_value=LOG_EPS,
    )

    device = model.device
    encoder_out, encoder_out_length = model.encoder(
        features=padded.to(device),
        features_length=features_length.to(device),
    )

    # The search receives the encoder output lengths on the CPU.
    hyp_tokens = modified_beam_search(
        model=model,
        encoder_out=encoder_out,
        encoder_out_length=encoder_out_length.cpu(),
        num_active_paths=num_active_paths,
    )
    return hyp_tokens