diff --git a/app.py b/app.py index ed9189b408434182592e1b8699b08226f5d319c7..5b36fbe8c2c8fc99015ab3c0ffa0f8093a6c66e5 100644 --- a/app.py +++ b/app.py @@ -10,9 +10,9 @@ import os import sys +import yaml sys.path.insert(0, './') -from gxl_ai_utils.utils import utils_file from wenet.utils.init_tokenizer import init_tokenizer from wenet.utils.init_model import init_model import logging @@ -20,6 +20,14 @@ import librosa import torch import torchaudio import numpy as np +def makedir_for_file(filepath): + dirpath = os.path.dirname(filepath) + if not os.path.exists(dirpath): + os.makedirs(dirpath) +def load_dict_from_yaml(file_path: str): + with open(file_path, 'rt', encoding='utf-8') as f: + dict_1 = yaml.load(f, Loader=yaml.FullLoader) + return dict_1 # 将图片转换为 Base64 with open("lab.png", "rb") as image_file: @@ -53,7 +61,7 @@ def init_model_my(): args = SimpleNamespace(**{ "checkpoint": checkpoint_path, }) - configs = utils_file.load_dict_from_yaml(config_path) + configs = load_dict_from_yaml(config_path) model, configs = init_model(args, configs) model = model.cuda() tokenizer = init_tokenizer(configs) @@ -73,7 +81,7 @@ def do_resample(input_wav_path, output_wav_path): waveform = torch.mean(waveform, dim=0, keepdim=True) waveform = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=16000)(waveform) - utils_file.makedir_for_file(output_wav_path) + makedir_for_file(output_wav_path) torchaudio.save(output_wav_path, waveform, 16000) def true_decode_fuc(input_wav_path, input_prompt): diff --git a/wenet/LLM/causallm_model.py b/wenet/LLM/causallm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..adf37d71d5aa3abc3d76b10bf1546af81818c4d2 --- /dev/null +++ b/wenet/LLM/causallm_model.py @@ -0,0 +1,207 @@ +from typing import Dict, List, Optional, Union +import torch +from wenet.LLM.decoder import DecoderOnly +from wenet.LLM.sampler import sampler +from wenet.utils.common import IGNORE_ID, th_accuracy +from wenet.utils.mask import make_pad_mask, subsequent_mask + + +class CausalLM(torch.nn.Module): + + def __init__( + self, + vocab_size: int, + decoder: DecoderOnly, + special_tokens: dict, + tie_word_embedding: bool = False, + linear_bias: bool = False, + ignore_id: int = IGNORE_ID, + lsm_weight: float = 0.0, + reduction: str = 'mean', + ) -> None: + super().__init__() + del special_tokens + + self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size) + self.out = torch.nn.Linear(decoder.hidden_size, + vocab_size, + bias=linear_bias) + + self.decoder = decoder + self.vocab_size = vocab_size + self.criterion_att = torch.nn.CrossEntropyLoss( + ignore_index=ignore_id, + label_smoothing=lsm_weight, + reduction=reduction, + ) + self.tie_word_embedding = tie_word_embedding + self.ignore_id = ignore_id + + @torch.jit.unused + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + """ Forward for training + """ + text = batch['feats'].to(device) + target = batch['target'].to(device) + text_length = batch['feats_lengths'].to(device) + + mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze( + 1) # (B,1,L) + causal_mask = subsequent_mask( + mask.size(-1), device=mask.device).unsqueeze(0) # (1,L,L) + att_mask = causal_mask & mask # (B, L, L) + + embeding = self.embed(text) + decoder_out = self.out(self.decoder(embeding, + att_mask)[0]) # (B, L, vocab_size) + loss = self.criterion_att(decoder_out.view(-1, self.vocab_size), + target.view(-1)) + acc = th_accuracy(decoder_out.view(-1, self.vocab_size), + target, + 
ignore_label=self.ignore_id) + + return { + "loss": loss, + "ppl": torch.exp(loss.detach()), + "th_accuracy": acc + } + + def tie_or_clone_weights(self, jit_mode: bool): + if not self.tie_word_embedding: + return + if jit_mode: + self.out.weight = torch.nn.Parameter(self.embed.weight.clone()) + else: + self.out.weight = self.embed.weight + # TODO(Mddct): whether to deal bias for other llm model + + @torch.jit.unused + @torch.inference_mode() + def generate( + self, + prompts_tokens: List[List[int]], + device: torch.device, + stop_tokens: List[int], + dtype: torch.dtype = torch.float32, + output_len: int = 100, + temperature: Union[float, None] = 0.95, + top_p: float = 1.0, + top_k: int = 100, + ) -> List[List[int]]: + """Generates responses for given prompts using Gemma model.""" + # If a single prompt is provided, treat it as a batch of 1. + batch_size = len(prompts_tokens) + min_prompt_len = min(len(p) for p in prompts_tokens) + max_prompt_len = max(len(p) for p in prompts_tokens) + max_seq_len = max_prompt_len + output_len + assert max_seq_len <= self.decoder.pos_enc.max_len + + # build KV caches + kv_caches = [] + for _ in range(len(self.decoder.decoders)): + size = (batch_size, 0, self.decoder.n_kv_head, + self.decoder.head_dim) + k_cache = torch.zeros(size=size, dtype=dtype, device=device) + v_cache = torch.zeros(size=size, dtype=dtype, device=device) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full((batch_size, max_seq_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + input_token_ids_tensor = torch.full((batch_size, min_prompt_len), + IGNORE_ID, + dtype=torch.int64, + device=device) + # right padding + for i, p in enumerate(prompts_tokens): + token_ids_tensor[i, :len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len]) + + prompt_mask_tensor = token_ids_tensor != IGNORE_ID + input_positions_tensor = torch.arange(0, + min_prompt_len, + dtype=torch.int64).to(device) + mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len), + dtype=torch.bool) + mask_tensor = torch.tril(mask_tensor).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze( + 1)[:, :min_prompt_len, :min_prompt_len] + output_positions_tensor = torch.LongTensor([min_prompt_len - 1 + ]).to(device) + temperatures_tensor = None if not temperature else torch.FloatTensor( + [temperature] * batch_size).to(device) + top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) + top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) + output_index = torch.tensor(min_prompt_len, + dtype=torch.int64).to(device) + + input_token_embeding = self.embed(input_token_ids_tensor) + offset = torch.tensor([0] * len(prompts_tokens)).to(device) + input_offset = offset + + stop_tokens_tensor = torch.tensor(stop_tokens, device=device) + # Prefill up to min_prompt_len tokens, then treat other prefill as + # decode and ignore output. 
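+        # Each iteration below runs a single decoder step. For prompts longer
+        # than min_prompt_len, the token sampled at a prompt position is
+        # discarded: torch.where(curr_prompt_mask, ...) writes the original
+        # prompt token back into token_ids_tensor, so only genuine generation
+        # positions keep sampled tokens.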
+ for i in range(max_seq_len - min_prompt_len): + decoder_out, kv_caches, = self.decoder( + input_token_embeding, + att_mask, + input_offset, + kv_caches, + ) + decoder_out = self.out(decoder_out) + decoder_out = decoder_out.index_select(1, output_positions_tensor) + next_token_ids = sampler( + decoder_out, + temperatures_tensor, + top_ps_tensor, + top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select( + 1, output_index).squeeze(dim=1) + output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, + next_token_ids).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_token_embeding = self.embed(input_token_ids_tensor) + + input_positions_tensor = output_index.unsqueeze(dim=-1) + curr_mask_tensor = mask_tensor.index_select( + 2, input_positions_tensor) + att_mask = curr_mask_tensor.squeeze(1)[:, :output_index + + 1, :output_index + 1] + + output_positions_tensor = torch.tensor( + 0, dtype=torch.int64).to(device) + input_offset = offset + output_index.unsqueeze(-1) + output_index = output_index + 1 + + if all(torch.isin(next_token_ids, stop_tokens_tensor)): + break + + token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[len(prompts_tokens[i] + ):len(prompts_tokens[i]) + output_len] + for stop_token in stop_tokens: + try: + eos_index = trimmed_output.index(stop_token) + trimmed_output = trimmed_output[:eos_index] + break + except Exception: + continue + results.append(trimmed_output) + + return results diff --git a/wenet/LLM/decoder.py b/wenet/LLM/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b25ee75dd67c1cbce424568d5ef99176cb52ff8b --- /dev/null +++ b/wenet/LLM/decoder.py @@ -0,0 +1,161 @@ +from functools import partial +from typing import List, Optional, Tuple, Union +import torch +import torch.utils.checkpoint as ckpt +from wenet.transformer.attention import T_CACHE + +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES, + WENET_ATTENTION_CLASSES, + WENET_EMB_CLASSES, WENET_MLP_CLASSES, + WENET_NORM_CLASSES) +from wenet.utils.common import mask_to_bias + + +class DecoderOnly(torch.nn.Module): + + def __init__( + self, + n_kv_head: int, + head_dim: int, + hidden_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + normalize_before: bool = True, + query_bias: bool = False, + key_bias: bool = False, + value_bias: bool = False, + mlp_bias: bool = False, + activation_type: str = "gelu", + gelu_approximate: Union[str, None] = None, + max_position_embeding: int = 8192, + mlp_type: str = 'gated', + layer_norm_type: str = 'rms_norm', + norm_eps: float = 1e-5, + rms_norm_offset: bool = True, + selfattention_layer_type: str = "rope_abs_selfattn", + use_sdpa: bool = False, + gradient_checkpointing: bool = False, + rope_theta: float = 10000.0, + rope_style: str = 'google', + scale_embed: bool = True, + ) -> None: + super().__init__() + + assert selfattention_layer_type in ['rope_abs_selfattn'] + self.pos_enc = WENET_EMB_CLASSES["rope_pos"]( + hidden_size, + head_dim, + max_len=max_position_embeding, + dropout_rate=positional_dropout_rate, + rope_theta=rope_theta, + scale=scale_embed) + if 
activation_type == "gelu" and gelu_approximate is not None: + activation = WENET_ACTIVATION_CLASSES['gelu']( + approximate=gelu_approximate) + else: + activation = WENET_ACTIVATION_CLASSES[activation_type]() + + mlp_class = WENET_MLP_CLASSES[mlp_type] + self.num_blocks = num_blocks + # TODO: support lora & refactor lora + self.decoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + hidden_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + attention_heads, + hidden_size, + attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, + style=rope_style), + mlp_class(hidden_size, linear_units, dropout_rate, activation, + mlp_bias), + dropout_rate, + normalize_before, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + rms_norm_offset=rms_norm_offset, + ) for _ in range(self.num_blocks) + ]) + self.pre_norm = normalize_before + self.final_norm: Optional[torch.nn.Module] = None + if self.pre_norm: + norm_class = WENET_NORM_CLASSES[layer_norm_type] + if layer_norm_type == "rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.final_norm = norm_class(hidden_size, eps=norm_eps) + + self.n_kv_head = n_kv_head + self.head_dim = head_dim + self._hidden_size = hidden_size + self.use_sdpa = use_sdpa + self.gradient_checkpointing = gradient_checkpointing + + def forward( + self, + input: torch.Tensor, + att_mask: torch.Tensor, + input_position: Union[int, torch.Tensor] = 0, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + xs, pos_emb = self.pos_enc(input, offset=input_position) + if self.use_sdpa: + att_mask = mask_to_bias(att_mask, xs.dtype) + + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb) + else: + xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb, + kv_caches) + if self.pre_norm and self.final_norm is not None: + xs = self.final_norm(xs) + return xs, kv_caches + + def forward_layers( + self, + xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + if self.training: + for (i, layer) in enumerate(self.decoders): + xs, _, _, _ = layer(xs, att_mask, pos_emb) + new_kv_caches = kv_caches + else: + assert kv_caches is not None + new_kv_caches = [] + for (i, layer) in enumerate(self.decoders): + xs, _, new_kv_cache, _ = layer(xs, + att_mask, + pos_emb, + att_cache=(kv_caches[i][0], + kv_caches[i][1])) + new_kv_caches.append(new_kv_cache) + + return xs, new_kv_caches + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask, + pos_emb) + return xs + + @property + def hidden_size(self): + return self._hidden_size diff --git a/wenet/LLM/sampler.py b/wenet/LLM/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..19f0d5cdaffd11cc2faf1fdbf2e61771635efa4b --- /dev/null +++ b/wenet/LLM/sampler.py @@ -0,0 +1,43 @@ +from typing import Union +import torch + + +# modified from https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L26 +@torch.no_grad() +def sampler( + logits: torch.Tensor, + temperatures: Union[torch.Tensor, None], + top_ps: torch.Tensor, + top_ks: torch.Tensor, +) -> torch.Tensor: + assert logits.size(1) == 1 + logits = 
logits.squeeze(1) # (batch_size, vocab_size) + if temperatures is None: + return torch.argmax(logits, dim=-1).squeeze(dim=-1) + + # Apply temperature scaling. + logits.div_(temperatures.unsqueeze(dim=1)) + + # Calculate probabilities with softmax. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + + # Apply top-p, top-k. + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) + probs_sort = torch.where(top_ps_mask, 0, probs_sort) + + top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) + top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) + top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) + probs_sort = torch.where(top_ks_mask, 0, probs_sort) + + # Re-normalization. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + probs = torch.gather(probs_sort, + dim=-1, + index=torch.argsort(probs_idx, dim=-1)) + + next_token_ids = torch.multinomial(probs, num_samples=1, + replacement=True).squeeze(dim=-1) + return next_token_ids diff --git a/wenet/__init__.py b/wenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..820ad3180b78bec00478ba4e30ce4b515967e405 --- /dev/null +++ b/wenet/__init__.py @@ -0,0 +1 @@ +from wenet.cli.model import load_model # noqa diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..12c272a2bd5829f35d767eac558bea3dbdffdf5f --- /dev/null +++ b/wenet/bin/alignment.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu) +# 2022 Tinnove Inc (authors: Wei Ren) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
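+
+# Generate CTC forced alignments for each utterance in --input_file and write
+# them to --result_file; with --gen_praat, per-utterance .lab and Praat
+# TextGrid files are also written next to the result file. Illustrative
+# invocation (paths are examples only):
+#   python wenet/bin/alignment.py --config train.yaml --checkpoint final.pt \
+#       --dict units.txt --input_file data.list --result_file align.txt --gen_praat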
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader +from textgrid import TextGrid, IntervalTier +import math + +from wenet.dataset.dataset import Dataset +from wenet.utils.ctc_utils import force_align +from wenet.utils.common import get_subsample +from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer + + +def generator_textgrid(maxtime, lines, output): + # Download Praat: https://www.fon.hum.uva.nl/praat/ + interval = maxtime / (len(lines) + 1) + margin = 0.0001 + + tg = TextGrid(maxTime=maxtime) + linetier = IntervalTier(name="line", maxTime=maxtime) + + i = 0 + for l in lines: + s, e, w = l.split() + linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w) + + tg.append(linetier) + print("successfully generator {}".format(output)) + tg.write(output) + + +def get_frames_timestamp(alignment, + prob, + blank_thres=0.999, + thres=0.0000000001): + # convert alignment to a praat format, which is a doing phonetics + # by computer and helps analyzing alignment + timestamp = [] + # get frames level duration for each token + start = 0 + end = 0 + local_start = 0 + while end < len(alignment): + while end < len(alignment) and alignment[end] == 0: + end += 1 + if end == len(alignment): + timestamp[-1] += alignment[start:] + break + end += 1 + while end < len(alignment) and alignment[end - 1] == alignment[end]: + end += 1 + local_start = end - 1 + # find the possible front border for current token + while local_start >= start and ( + prob[local_start][0] < math.log(blank_thres) + or prob[local_start][alignment[end - 1]] > math.log(thres)): + alignment[local_start] = alignment[end - 1] + local_start -= 1 + cur_alignment = alignment[start:end] + timestamp.append(cur_alignment) + start = end + return timestamp + + +def get_labformat(timestamp, subsample): + begin = 0 + begin_time = 0 + duration = 0 + labformat = [] + for idx, t in enumerate(timestamp): + # 25ms frame_length,10ms hop_length, 1/subsample + subsample = get_subsample(configs) + # time duration + i = 0 + while t[i] == 0: + i += 1 + begin = i + dur = 0 + while i < len(t) and t[i] != 0: + i += 1 + dur += 1 + begin = begin_time + begin * 0.01 * subsample + duration = dur * 0.01 * subsample + if idx < len(timestamp) - 1: + print("{:.2f} {:.2f} {}".format(begin, begin + duration, + char_dict[t[-1]])) + labformat.append("{:.2f} {:.2f} {}\n".format( + begin, begin + duration, char_dict[t[-1]])) + else: # last token + non_blank = 0 + for i in t: + if i != 0: + token = i + break + print("{:.2f} {:.2f} {}".format(begin, begin + duration, + char_dict[token])) + labformat.append("{:.2f} {:.2f} {}\n".format( + begin, begin + duration, char_dict[token])) + begin_time += len(t) * 0.01 * subsample + return labformat + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='use ctc to generate alignment') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--input_file', required=True, help='format data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--device', + type=str, + default="cpu", + choices=["cpu", "npu", "cuda"], + help='accelerator to use') + parser.add_argument('--blank_thres', + default=0.999999, + 
type=float, + help='ctc blank thes') + parser.add_argument('--thres', + default=0.000001, + type=float, + help='ctc non blank thes') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument( + '--non_lang_syms', + help="non-linguistic symbol file. One symbol per line.") + parser.add_argument('--result_file', + required=True, + help='alignment result file') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--gen_praat', + action='store_true', + help='convert alignment to a praat format') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + + args = parser.parse_args() + print(args) + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + if args.gpu != -1: + # remain the original usage of gpu + args.device = "cuda" + if "cuda" in args.device: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.batch_size > 1: + logging.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + # Load dict + char_dict = {} + with open(args.dict, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + # Init dataset and data loader + ali_conf = copy.deepcopy(configs['dataset_conf']) + + ali_conf['filter_conf']['max_length'] = 102400 + ali_conf['filter_conf']['min_length'] = 0 + ali_conf['filter_conf']['token_max_length'] = 102400 + ali_conf['filter_conf']['token_min_length'] = 0 + ali_conf['filter_conf']['max_output_input_ratio'] = 102400 + ali_conf['filter_conf']['min_output_input_ratio'] = 0 + ali_conf['speed_perturb'] = False + ali_conf['spec_aug'] = False + ali_conf['spec_trim'] = False + ali_conf['shuffle'] = False + ali_conf['sort'] = False + ali_conf['fbank_conf']['dither'] = 0.0 + ali_conf['batch_conf']['batch_type'] = "static" + ali_conf['batch_conf']['batch_size'] = args.batch_size + + tokenizer = init_tokenizer(configs) + ali_dataset = Dataset(args.data_type, + args.input_file, + tokenizer, + ali_conf, + partition=False) + + ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0) + + # Init asr model from configs + model, configs = init_model(args, configs) + + device = torch.device(args.device) + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.result_file, 'w', + encoding='utf-8') as fout: + for batch_idx, batch in enumerate(ali_data_loader): + print("#" * 80) + key, feat, target, feats_length, target_length = batch + + feat = feat.to(device) + target = target.to(device) + feats_length = feats_length.to(device) + target_length = target_length.to(device) + # Let's assume B = batch_size and N = beam_size + # 1. 
Encoder + encoder_out, encoder_mask = model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + # print(ctc_probs.size(1)) + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = force_align(ctc_probs, target) + fout.write('{} {}\n'.format(key[0], alignment)) + + if args.gen_praat: + timestamp = get_frames_timestamp(alignment, ctc_probs, + args.blank_thres, args.thres) + subsample = get_subsample(configs) + labformat = get_labformat(timestamp, subsample) + + lab_path = os.path.join(os.path.dirname(args.result_file), + key[0] + ".lab") + with open(lab_path, 'w', encoding='utf-8') as f: + f.writelines(labformat) + + textgrid_path = os.path.join(os.path.dirname(args.result_file), + key[0] + ".TextGrid") + generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 * + subsample, + lines=labformat, + output=textgrid_path) diff --git a/wenet/bin/average_model.py b/wenet/bin/average_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1e97b059ce54b3fc0c27099c030239198f7ebf9d --- /dev/null +++ b/wenet/bin/average_model.py @@ -0,0 +1,125 @@ +# Copyright (c) 2020 Mobvoi Inc (Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
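+
+# Average the parameters of several checkpoints into one model: with
+# --val_best, the --num checkpoints with the lowest validation loss (read
+# from the per-checkpoint yaml files) are used; otherwise the --num most
+# recently written *.pt files under --src_path are used. Illustrative
+# invocation (paths are examples only):
+#   python wenet/bin/average_model.py --src_path exp/ --dst_model exp/avg_5.pt \
+#       --num 5 --val_best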
+ +import os +import argparse +import glob +import sys + +import yaml +import torch + + +def get_args(): + parser = argparse.ArgumentParser(description='average model') + parser.add_argument('--dst_model', required=True, help='averaged model') + parser.add_argument('--src_path', + required=True, + help='src model path for average') + parser.add_argument('--val_best', + action="store_true", + help='averaged model') + parser.add_argument('--num', + default=5, + type=int, + help='nums for averaged model') + parser.add_argument('--min_epoch', + default=0, + type=int, + help='min epoch used for averaging model') + parser.add_argument('--max_epoch', + default=sys.maxsize, + type=int, + help='max epoch used for averaging model') + parser.add_argument('--min_step', + default=0, + type=int, + help='min step used for averaging model') + parser.add_argument('--max_step', + default=sys.maxsize, + type=int, + help='max step used for averaging model') + parser.add_argument('--mode', + default="hybrid", + choices=["hybrid", "epoch", "step"], + type=str, + help='average mode') + + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + checkpoints = [] + val_scores = [] + if args.val_best: + if args.mode == "hybrid": + yamls = glob.glob('{}/*.yaml'.format(args.src_path)) + yamls = [ + f for f in yamls + if not (os.path.basename(f).startswith('train') + or os.path.basename(f).startswith('init')) + ] + elif args.mode == "step": + yamls = glob.glob('{}/step_*.yaml'.format(args.src_path)) + else: + yamls = glob.glob('{}/epoch_*.yaml'.format(args.src_path)) + for y in yamls: + with open(y, 'r') as f: + dic_yaml = yaml.load(f, Loader=yaml.FullLoader) + loss = dic_yaml['loss_dict']['loss'] + epoch = dic_yaml['epoch'] + step = dic_yaml['step'] + tag = dic_yaml['tag'] + if epoch >= args.min_epoch and epoch <= args.max_epoch \ + and step >= args.min_step and step <= args.max_step: + val_scores += [[epoch, step, loss, tag]] + sorted_val_scores = sorted(val_scores, + key=lambda x: x[2], + reverse=False) + print("best val (epoch, step, loss, tag) = " + + str(sorted_val_scores[:args.num])) + path_list = [ + args.src_path + '/{}.pt'.format(score[-1]) + for score in sorted_val_scores[:args.num] + ] + else: + path_list = glob.glob('{}/[!init]*.pt'.format(args.src_path)) + path_list = sorted(path_list, key=os.path.getmtime) + path_list = path_list[-args.num:] + print(path_list) + avg = {} + num = args.num + assert num == len(path_list) + for path in path_list: + print('Processing {}'.format(path)) + states = torch.load(path, map_location=torch.device('cpu')) + for k in states.keys(): + if k not in avg.keys(): + avg[k] = states[k].clone() + else: + avg[k] += states[k] + # average + for k in avg.keys(): + if avg[k] is not None: + # pytorch 1.6 use true_divide instead of /= + avg[k] = torch.true_divide(avg[k], num) + print('Saving to {}'.format(args.dst_model)) + torch.save(avg, args.dst_model) + + +if __name__ == '__main__': + main() diff --git a/wenet/bin/export_ipex.py b/wenet/bin/export_ipex.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3ff181cf5233932ed5641abc231407c815c2d5 --- /dev/null +++ b/wenet/bin/export_ipex.py @@ -0,0 +1,95 @@ +# Copyright (C) 2021-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import print_function + +import argparse +import logging +import os + +import torch +import yaml + +from wenet.utils.init_model import init_model +import intel_extension_for_pytorch as ipex +from 
intel_extension_for_pytorch.quantization import prepare, convert + + +def get_args(): + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_file', default=None, help='output file') + parser.add_argument('--dtype', + default="fp32", + help='choose the dtype to run:[fp32,bf16]') + parser.add_argument('--output_quant_file', + default=None, + help='output quantized model file') + args = parser.parse_args() + return args + + +def scripting(model): + with torch.inference_mode(): + script_model = torch.jit.script(model) + script_model = torch.jit.freeze( + script_model, + preserved_attrs=[ + "forward_encoder_chunk", "ctc_activation", + "forward_attention_decoder", "subsampling_rate", + "right_context", "sos_symbol", "eos_symbol", + "is_bidirectional_decoder" + ]) + return script_model + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + # No need gpu for model export + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + model, configs = init_model(args, configs) + print(model) + + # Apply IPEX optimization + model.eval() + torch._C._jit_set_texpr_fuser_enabled(False) + model.to(memory_format=torch.channels_last) + if args.dtype == "fp32": + ipex_model = ipex.optimize(model) + elif args.dtype == "bf16": # For Intel 4th generation Xeon (SPR) + ipex_model = ipex.optimize(model, + dtype=torch.bfloat16, + weights_prepack=False) + + # Export jit torch script model + if args.output_file: + if args.dtype == "fp32": + script_model = scripting(ipex_model) + elif args.dtype == "bf16": + torch._C._jit_set_autocast_mode(True) + with torch.cpu.amp.autocast(): + script_model = scripting(ipex_model) + script_model.save(args.output_file) + print('Export model successfully, see {}'.format(args.output_file)) + + # Export quantized jit torch script model + if args.output_quant_file: + dynamic_qconfig = ipex.quantization.default_dynamic_qconfig + dummy_data = (torch.zeros(1, 67, 80), 16, -16, + torch.zeros(12, 4, 32, 128), torch.zeros(12, 1, 256, 7)) + model = prepare(model, dynamic_qconfig, dummy_data) + model = convert(model) + script_quant_model = scripting(model) + script_quant_model.save(args.output_quant_file) + print('Export quantized model successfully, ' + 'see {}'.format(args.output_quant_file)) + + +if __name__ == '__main__': + main() diff --git a/wenet/bin/export_jit.py b/wenet/bin/export_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..98eadd61cb813c9deb97e8351d17d83789e389a9 --- /dev/null +++ b/wenet/bin/export_jit.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
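+
+# Export a trained checkpoint to TorchScript via torch.jit.script
+# (--output_file) and, optionally, to a dynamically quantized TorchScript
+# model (--output_quant_file). Illustrative invocation (paths are examples only):
+#   python wenet/bin/export_jit.py --config train.yaml --checkpoint final.pt \
+#       --output_file final.zip --output_quant_file final_quant.zip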
+ +from __future__ import print_function + +import argparse +import logging +import os + +import torch +import yaml + +from wenet.utils.init_model import init_model + + +def get_args(): + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_file', default=None, help='output file') + parser.add_argument('--output_quant_file', + default=None, + help='output quantized model file') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + args.jit = True + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + # No need gpu for model export + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + model, configs = init_model(args, configs) + model.eval() + print(model) + # Export jit torch script model + + if args.output_file: + script_model = torch.jit.script(model) + script_model.save(args.output_file) + print('Export model successfully, see {}'.format(args.output_file)) + + # Export quantized jit torch script model + if args.output_quant_file: + quantized_model = torch.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8) + print(quantized_model) + script_quant_model = torch.jit.script(quantized_model) + script_quant_model.save(args.output_quant_file) + print('Export quantized model successfully, ' + 'see {}'.format(args.output_quant_file)) + + +if __name__ == '__main__': + main() diff --git a/wenet/bin/export_onnx_bpu.py b/wenet/bin/export_onnx_bpu.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d93a022e865afbf424faabd552404b7bcdd4ef --- /dev/null +++ b/wenet/bin/export_onnx_bpu.py @@ -0,0 +1,1065 @@ +# Copyright (c) 2022, Horizon Inc. Xingchen Song (sxc19@tsinghua.org.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NOTE(xcsong): Currently, we only support +1. specific conformer encoder architecture, see: + encoder: conformer + encoder_conf: + activation_type: **must be** relu + attention_heads: 2 or 4 or 8 or any number divisible by output_size + causal: **must be** true + cnn_module_kernel: 1 ~ 7 + cnn_module_norm: **must be** batch_norm + input_layer: **must be** conv2d8 + linear_units: 1 ~ 2048 + normalize_before: **must be** true + num_blocks: 1 ~ 12 + output_size: 1 ~ 512 + pos_enc_layer_type: **must be** no_pos + selfattention_layer_type: **must be** selfattn + use_cnn_module: **must be** true + use_dynamic_chunk: **must be** true + use_dynamic_left_chunk: **must be** true + +2. 
specific decoding method: ctc_greedy_search +""" + +from __future__ import print_function + +import os +import sys +import copy +import math +import yaml +import logging +from typing import Tuple + +import torch +import numpy as np + +from wenet.transformer.embedding import NoPositionalEncoding +from wenet.utils.init_model import init_model +from wenet.bin.export_onnx_cpu import (get_args, to_numpy, + print_input_output_info) + +try: + import onnx + import onnxruntime +except ImportError: + print('Please install onnx and onnxruntime!') + sys.exit(1) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +class BPULayerNorm(torch.nn.Module): + """Refactor torch.nn.LayerNorm to meet 4-D dataflow.""" + + def __init__(self, module, chunk_size=8, run_on_bpu=False): + super().__init__() + original = copy.deepcopy(module) + self.hidden = module.weight.size(0) + self.chunk_size = chunk_size + self.run_on_bpu = run_on_bpu + + if self.run_on_bpu: + self.weight = torch.nn.Parameter( + module.weight.reshape(1, self.hidden, 1, + 1).repeat(1, 1, 1, chunk_size)) + self.bias = torch.nn.Parameter( + module.bias.reshape(1, self.hidden, 1, + 1).repeat(1, 1, 1, chunk_size)) + self.negtive = torch.nn.Parameter( + torch.ones((1, self.hidden, 1, chunk_size)) * -1.0) + self.eps = torch.nn.Parameter( + torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps) + self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) + self.mean_conv_1.weight = torch.nn.Parameter( + torch.ones(self.hidden, self.hidden, 1, 1) / + (1.0 * self.hidden)) + self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False) + self.mean_conv_2.weight = torch.nn.Parameter( + torch.ones(self.hidden, self.hidden, 1, 1) / + (1.0 * self.hidden)) + else: + self.norm = module + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, self.chunk_size, self.hidden) + orig_out = module(random_data) + new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2)) + np.testing.assert_allclose(to_numpy(orig_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.run_on_bpu: + u = self.mean_conv_1(x) # (1, h, 1, c) + numerator = x + u * self.negtive # (1, h, 1, c) + s = torch.pow(numerator, 2) # (1, h, 1, c) + s = self.mean_conv_2(s) # (1, h, 1, c) + denominator = torch.sqrt(s + self.eps) # (1, h, 1, c) + x = torch.div(numerator, denominator) # (1, h, 1, c) + x = x * self.weight + self.bias + else: + x = x.squeeze(2).transpose(1, 2).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).contiguous().unsqueeze(2) + return x + + +class BPUIdentity(torch.nn.Module): + """Refactor torch.nn.Identity(). + For inserting BPU node whose input == output. + """ + + def __init__(self, channels): + super().__init__() + self.channels = channels + self.identity_conv = torch.nn.Conv2d(channels, + channels, + 1, + groups=channels, + bias=False) + torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels) + + self.check_equal() + + def check_equal(self): + random_data = torch.randn(1, self.channels, 1, 10) + result = self.forward(random_data) + np.testing.assert_allclose(to_numpy(random_data), + to_numpy(result), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Identity with 4-D dataflow, input == output. + Args: + x (torch.Tensor): (batch, in_channel, 1, time) + + Returns: + (torch.Tensor): (batch, in_channel, 1, time). 
+ """ + return self.identity_conv(x) + + +class BPULinear(torch.nn.Module): + """Refactor torch.nn.Linear or pointwise_conv""" + + def __init__(self, module, is_pointwise_conv=False): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.idim = module.weight.size(1) + self.odim = module.weight.size(0) + self.is_pointwise_conv = is_pointwise_conv + + # Modify weight & bias + self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1) + if is_pointwise_conv: + # (odim, idim, kernel=1) -> (odim, idim, 1, 1) + self.linear.weight = torch.nn.Parameter( + module.weight.unsqueeze(-1)) + else: + # (odim, idim) -> (odim, idim, 1, 1) + self.linear.weight = torch.nn.Parameter( + module.weight.unsqueeze(2).unsqueeze(3)) + self.linear.bias = module.bias + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, 8, self.idim) + if self.is_pointwise_conv: + random_data = random_data.transpose(1, 2) + original_result = module(random_data) + if self.is_pointwise_conv: + random_data = random_data.transpose(1, 2) + original_result = original_result.transpose(1, 2) + random_data = random_data.transpose(1, 2).unsqueeze(2) + new_result = self.forward(random_data) + np.testing.assert_allclose(to_numpy(original_result), + to_numpy( + new_result.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Linear with 4-D dataflow. + Args: + x (torch.Tensor): (batch, in_channel, 1, time) + Returns: + (torch.Tensor): (batch, out_channel, 1, time). + """ + return self.linear(x) + + +class BPUGlobalCMVN(torch.nn.Module): + """Refactor wenet/transformer/cmvn.py::GlobalCMVN""" + + def __init__(self, module): + super().__init__() + # Unchanged submodules and attributes + self.norm_var = module.norm_var + + # NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1) + self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0) + self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """CMVN with 4-D dataflow. + Args: + x (torch.Tensor): (batch, 1, mel_dim, time) + Returns: + (torch.Tensor): normalized feature with same shape. + """ + x = x - self.mean + if self.norm_var: + x = x * self.istd + return x + + +class BPUConv2dSubsampling8(torch.nn.Module): + """Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8 + + NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding + """ + + def __init__(self, module): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.right_context = module.right_context + self.subsampling_rate = module.subsampling_rate + assert isinstance(module.pos_enc, NoPositionalEncoding) + + # 1. Modify self.conv + # NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim) + # to (1, 1, mel_dim, frames) for more efficient computation. + self.conv = module.conv + for idx in [0, 2, 4]: + self.conv[idx].weight = torch.nn.Parameter( + module.conv[idx].weight.transpose(2, 3)) + + # 2. 
Modify self.linear + # NOTE(xcsong): Split final projection to meet the requirment of + # maximum kernel_size (7 for XJ3) + self.linear = torch.nn.ModuleList() + odim = module.linear.weight.size(0) # 512, in this case + freq = module.linear.weight.size(1) // odim # 4608 // 512 == 9 + self.odim, self.freq = odim, freq + weight = module.linear.weight.reshape( + odim, odim, freq, + 1) # (odim, odim * freq) -> (odim, odim, freq, 1) + self.split_size = [] + num_split = (freq - 1) // 7 + 1 # XJ3 requires kernel_size <= 7 + slice_begin = 0 + for idx in range(num_split): + kernel_size = min(freq, (idx + 1) * 7) - idx * 7 + conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1), + (kernel_size, 1)) + conv_ele.weight = torch.nn.Parameter( + weight[:, :, slice_begin:slice_begin + kernel_size, :]) + conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias)) + self.linear.append(conv_ele) + self.split_size.append(kernel_size) + slice_begin += kernel_size + self.linear[0].bias = torch.nn.Parameter(module.linear.bias) + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, 67, 80) + mask = torch.zeros(1, 1, 67) + original_result, _, _ = module(random_data, mask) # (1, 8, 512) + random_data = random_data.transpose(1, + 2).unsqueeze(0) # (1, 1, 80, 67) + new_result = self.forward(random_data) # (1, 512, 1, 8) + np.testing.assert_allclose(to_numpy(original_result), + to_numpy( + new_result.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x with 4-D dataflow. + Args: + x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, odim, 1, time'), + where time' = time // 8. + """ + x = self.conv(x) # (1, odim, freq, time') + x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3)) + x = torch.split(x, self.split_size, dim=2) + for idx, (x_part, layer) in enumerate(zip(x, self.linear)): + x_out += layer(x_part) + return x_out + + +class BPUMultiHeadedAttention(torch.nn.Module): + """Refactor wenet/transformer/attention.py::MultiHeadedAttention + + NOTE(xcsong): Only support attention_class == MultiHeadedAttention, + we do not consider RelPositionMultiHeadedAttention currently. + """ + + def __init__(self, module, chunk_size, left_chunks): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.d_k = module.d_k + self.h = module.h + n_feat = self.d_k * self.h + self.chunk_size = chunk_size + self.left_chunks = left_chunks + self.time = chunk_size * (left_chunks + 1) + self.activation = torch.nn.Softmax(dim=-1) + + # 1. Modify self.linear_x + self.linear_q = BPULinear(module.linear_q) + self.linear_k = BPULinear(module.linear_k) + self.linear_v = BPULinear(module.linear_v) + self.linear_out = BPULinear(module.linear_out) + # 2. 
denom + self.register_buffer( + "denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k))) + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, self.chunk_size, self.d_k * self.h) + mask = torch.ones((1, self.h, self.chunk_size, self.time), + dtype=torch.bool) + cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks, + self.d_k * 2) + original_out, original_cache = module(random_data, random_data, + random_data, mask[:, 0, :, :], + torch.empty(0), cache) + random_data = random_data.transpose(1, 2).unsqueeze(2) + cache = cache.reshape(1, self.h, self.d_k * 2, + self.chunk_size * self.left_chunks) + new_out, new_cache = self.forward(random_data, random_data, + random_data, mask, cache) + np.testing.assert_allclose(to_numpy(original_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_cache), + to_numpy(new_cache.transpose(2, 3)), + rtol=1e-02, + atol=1e-03) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor, + cache: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size). + k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size). + v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size). + mask (torch.Tensor): Mask tensor, + (#batch, head, chunk_size, cache_t + chunk_size). + cache (torch.Tensor): Cache tensor + (1, head, d_k * 2, cache_t), + where `cache_t == chunk_size * left_chunks`. + + + Returns: + torch.Tensor: Output tensor (#batch, size, 1, chunk_size). + torch.Tensor: Cache tensor + (1, head, d_k * 2, cache_t + chunk_size) + where `cache_t == chunk_size * left_chunks` + """ + # 1. Forward QKV + q = self.linear_q(q) # (1, d, 1, c) d == size, c == chunk_size + k = self.linear_k(k) # (1, d, 1, c) + v = self.linear_v(v) # (1, d, 1, c) + q = q.view(1, self.h, self.d_k, self.chunk_size) + k = k.view(1, self.h, self.d_k, self.chunk_size) + v = v.view(1, self.h, self.d_k, self.chunk_size) + q = q.transpose(2, 3) # (batch, head, time1, d_k) + k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2) + k = torch.cat((k_cache, k), dim=3) + v = torch.cat((v_cache, v), dim=3) + new_cache = torch.cat((k, v), dim=2) + # 2. (Q^T)K + scores = torch.matmul(q, k) * self.denom # (#b, n_head, time1, time2) + # 3. Forward attention + mask = mask.eq(0) + scores = scores.masked_fill(mask, -float('inf')) + attn = self.activation(scores).masked_fill(mask, 0.0) + attn = attn.transpose(2, 3) + x = torch.matmul(v, attn) + x = x.view(1, self.d_k * self.h, 1, self.chunk_size) + x_out = self.linear_out(x) + return x_out, new_cache + + +class BPUConvolution(torch.nn.Module): + """Refactor wenet/transformer/convolution.py::ConvolutionModule + + NOTE(xcsong): Only suport use_layer_norm == False + """ + + def __init__(self, module): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.lorder = module.lorder + self.use_layer_norm = False + self.activation = module.activation + channels = module.pointwise_conv1.weight.size(1) + self.channels = channels + kernel_size = module.depthwise_conv.weight.size(2) + assert module.use_layer_norm is False + + # 1. Modify self.pointwise_conv1 + self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True) + + # 2. 
Modify self.depthwise_conv + self.depthwise_conv = torch.nn.Conv2d(channels, + channels, (1, kernel_size), + stride=1, + groups=channels) + self.depthwise_conv.weight = torch.nn.Parameter( + module.depthwise_conv.weight.unsqueeze(-2)) + self.depthwise_conv.bias = torch.nn.Parameter( + module.depthwise_conv.bias) + + # 3. Modify self.norm, Only support batchnorm2d + self.norm = torch.nn.BatchNorm2d(channels) + self.norm.training = False + self.norm.num_features = module.norm.num_features + self.norm.eps = module.norm.eps + self.norm.momentum = module.norm.momentum + self.norm.weight = torch.nn.Parameter(module.norm.weight) + self.norm.bias = torch.nn.Parameter(module.norm.bias) + self.norm.running_mean = module.norm.running_mean + self.norm.running_var = module.norm.running_var + + # 4. Modify self.pointwise_conv2 + self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True) + + # 5. Identity conv, for running `concat` on BPU + self.identity = BPUIdentity(channels) + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, 8, self.channels) + cache = torch.zeros((1, self.channels, self.lorder)) + original_out, original_cache = module(random_data, cache=cache) + random_data = random_data.transpose(1, 2).unsqueeze(2) + cache = cache.unsqueeze(2) + new_out, new_cache = self.forward(random_data, cache) + np.testing.assert_allclose(to_numpy(original_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_cache), + to_numpy(new_cache.squeeze(2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor, + cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size). + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, 1, cache_t). + Returns: + torch.Tensor: Output tensor (#batch, channels, 1, chunk_size). + torch.Tensor: Cache tensor (#batch, channels, 1, cache_t). + """ + # Concat cache + x = torch.cat((self.identity(cache), self.identity(x)), dim=3) + new_cache = x[:, :, :, -self.lorder:] + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, 1, dim) + x = torch.nn.functional.glu(x, dim=1) # (b, channel, 1, dim) + + # Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + x = self.pointwise_conv2(x) + return x, new_cache + + +class BPUFFN(torch.nn.Module): + """Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward + """ + + def __init__(self, module): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.activation = module.activation + + # 1. Modify self.w_x + self.w_1 = BPULinear(module.w_1) + self.w_2 = BPULinear(module.w_2) + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, 8, self.w_1.idim) + original_out = module(random_data) + random_data = random_data.transpose(1, 2).unsqueeze(2) + new_out = self.forward(random_data) + np.testing.assert_allclose(to_numpy(original_out), + to_numpy( + new_out.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. 
+ + Args: + xs: input tensor (B, D, 1, L) + Returns: + output tensor, (B, D, 1, L) + """ + return self.w_2(self.activation(self.w_1(x))) + + +class BPUConformerEncoderLayer(torch.nn.Module): + """Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer + """ + + def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.size = module.size + assert module.normalize_before is True + assert module.concat_after is False + + # 1. Modify submodules + self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron) + self.self_attn = BPUMultiHeadedAttention(module.self_attn, chunk_size, + left_chunks) + self.conv_module = BPUConvolution(module.conv_module) + self.feed_forward = BPUFFN(module.feed_forward) + + # 2. Modify norms + self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu) + self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, + ln_run_on_bpu) + self.norm_ff_macron = BPULayerNorm(module.norm_ff_macaron, chunk_size, + ln_run_on_bpu) + self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size, + ln_run_on_bpu) + self.norm_final = BPULayerNorm(module.norm_final, chunk_size, + ln_run_on_bpu) + + # 3. 4-D ff_scale + self.register_buffer("ff_scale", + torch.full((1, self.size, 1, 1), module.ff_scale)) + + self.check_equal(original) + + def check_equal(self, module): + time1 = self.self_attn.chunk_size + time2 = self.self_attn.time + h, d_k = self.self_attn.h, self.self_attn.d_k + random_x = torch.randn(1, time1, self.size) + att_mask = torch.ones(1, h, time1, time2) + att_cache = torch.zeros(1, h, time2 - time1, d_k * 2) + cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder) + original_x, _, original_att_cache, original_cnn_cache = module( + random_x, + att_mask[:, 0, :, :], + torch.empty(0), + att_cache=att_cache, + cnn_cache=cnn_cache) + random_x = random_x.transpose(1, 2).unsqueeze(2) + att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1) + cnn_cache = cnn_cache.unsqueeze(2) + new_x, new_att_cache, new_cnn_cache = self.forward( + random_x, att_mask, att_cache, cnn_cache) + np.testing.assert_allclose(to_numpy(original_att_cache), + to_numpy(new_att_cache.transpose(2, 3)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_x), + to_numpy(new_x.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(original_cnn_cache), + to_numpy(new_cnn_cache.squeeze(2)), + rtol=1e-02, + atol=1e-03) + + def forward( + self, x: torch.Tensor, att_mask: torch.Tensor, att_cache: torch.Tensor, + cnn_cache: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, size, 1, chunk_size) + att_mask (torch.Tensor): Mask tensor for the input + (#batch, head, chunk_size, cache_t1 + chunk_size), + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, d_k * 2, cache_t1), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, 1, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, size, 1, chunk_size). + torch.Tensor: att_cache tensor, + (1, head, d_k * 2, cache_t1 + chunk_size). + torch.Tensor: cnn_cahce tensor (#batch, size, 1, cache_t2). + """ + # 1. ffn_macaron + residual = x + x = self.norm_ff_macron(x) + x = residual + self.ff_scale * self.feed_forward_macaron(x) + + # 2. 
attention + residual = x + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache) + x = residual + x_att + + # 3. convolution + residual = x + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, cnn_cache) + x = residual + x + + # 4. ffn + residual = x + x = self.norm_ff(x) + x = residual + self.ff_scale * self.feed_forward(x) + + # 5. final post-norm + x = self.norm_final(x) + + return x, new_att_cache, new_cnn_cache + + +class BPUConformerEncoder(torch.nn.Module): + """Refactor wenet/transformer/encoder.py::ConformerEncoder + """ + + def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + output_size = module.output_size() + self._output_size = module.output_size() + self.after_norm = module.after_norm + self.chunk_size = chunk_size + self.left_chunks = left_chunks + self.head = module.encoders[0].self_attn.h + self.layers = len(module.encoders) + + # 1. Modify submodules + self.global_cmvn = BPUGlobalCMVN(module.global_cmvn) + self.embed = BPUConv2dSubsampling8(module.embed) + self.encoders = torch.nn.ModuleList() + for layer in module.encoders: + self.encoders.append( + BPUConformerEncoderLayer(layer, chunk_size, left_chunks, + ln_run_on_bpu)) + + # 2. Auxiliary conv + self.identity_cnncache = BPUIdentity(output_size) + + self.check_equal(original) + + def check_equal(self, module): + time1 = self.encoders[0].self_attn.chunk_size + time2 = self.encoders[0].self_attn.time + layers = self.layers + h, d_k = self.head, self.encoders[0].self_attn.d_k + decoding_window = (self.chunk_size - 1) * \ + module.embed.subsampling_rate + \ + module.embed.right_context + 1 + lorder = self.encoders[0].conv_module.lorder + random_x = torch.randn(1, decoding_window, 80) + att_mask = torch.ones(1, h, time1, time2) + att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2) + cnn_cache = torch.zeros(layers, 1, self._output_size, lorder) + orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk( + random_x, + 0, + time2 - time1, + att_mask=att_mask[:, 0, :, :], + att_cache=att_cache, + cnn_cache=cnn_cache) + random_x = random_x.unsqueeze(0) + att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1) + cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder) + new_x, new_att_cache, new_cnn_cache = self.forward( + random_x, att_cache, cnn_cache, att_mask) + caches = torch.split(new_att_cache, h, dim=1) + caches = [c.transpose(2, 3) for c in caches] + np.testing.assert_allclose(to_numpy(orig_att_cache), + to_numpy(torch.cat(caches, dim=0)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose(to_numpy(orig_x), + to_numpy(new_x.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + np.testing.assert_allclose( + to_numpy(orig_cnn_cache), + to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + + def forward( + self, xs: torch.Tensor, att_cache: torch.Tensor, + cnn_cache: torch.Tensor, att_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (1, head * elayers, d_k * 2, cache_t1), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * left_chunks`. 
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (1, hidden-dim, elayers, cache_t2), where + `cache_t2 == cnn.lorder - 1` + att_mask (torch.Tensor): Mask tensor for the input + (#batch, head, chunk_size, cache_t1 + chunk_size), + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, hidden-dim, 1, chunk_size). + torch.Tensor: new attention cache required for next chunk, with + same shape as the original att_cache. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + """ + # xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time) + xs = xs.transpose(2, 3) + xs = self.global_cmvn(xs) + # xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size) + xs = self.embed(xs) + + att_cache = torch.split(att_cache, self.head, dim=1) + cnn_cache = self.identity_cnncache(cnn_cache) + cnn_cache = torch.split(cnn_cache, 1, dim=2) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + xs, new_att_cache, new_cnn_cache = layer(xs, + att_mask, + att_cache=att_cache[i], + cnn_cache=cnn_cache[i]) + r_att_cache.append(new_att_cache[:, :, :, self.chunk_size:]) + r_cnn_cache.append(new_cnn_cache) + r_att_cache = torch.cat(r_att_cache, dim=1) + r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2)) + + xs = xs.squeeze(2).transpose(1, 2).contiguous() + xs = self.after_norm(xs) + # NOTE(xcsong): 4D in, 4D out to meet the requirment of CTC input. + xs = xs.transpose(1, 2).contiguous().unsqueeze(2) # (B, C, 1, T) + + return (xs, r_att_cache, r_cnn_cache) + + +class BPUCTC(torch.nn.Module): + """Refactor wenet/transformer/ctc.py::CTC + """ + + def __init__(self, module): + super().__init__() + # Unchanged submodules and attributes + original = copy.deepcopy(module) + self.idim = module.ctc_lo.weight.size(1) + num_class = module.ctc_lo.weight.size(0) + + # 1. Modify self.ctc_lo, Split final projection to meet the + # requirment of maximum in/out channels (2048 for XJ3) + self.ctc_lo = torch.nn.ModuleList() + self.split_size = [] + num_split = (num_class - 1) // 2048 + 1 + for idx in range(num_split): + out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048 + conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1) + self.ctc_lo.append(conv_ele) + self.split_size.append(out_channel) + orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0) + orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0) + for i, (w, b) in enumerate(zip(orig_weight, orig_bias)): + w = w.unsqueeze(2).unsqueeze(3) + self.ctc_lo[i].weight = torch.nn.Parameter(w) + self.ctc_lo[i].bias = torch.nn.Parameter(b) + + self.check_equal(original) + + def check_equal(self, module): + random_data = torch.randn(1, 100, self.idim) + original_result = module.ctc_lo(random_data) + random_data = random_data.transpose(1, 2).unsqueeze(2) + new_result = self.forward(random_data) + np.testing.assert_allclose(to_numpy(original_result), + to_numpy( + new_result.squeeze(2).transpose(1, 2)), + rtol=1e-02, + atol=1e-03) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """frame activations, without softmax. 
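+        The final projection is evaluated as several 1x1 Conv2d slices
+        (each capped at 2048 output channels, see __init__) and the
+        partial results are concatenated along the class dimension.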
+ + Args: + Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size) + Returns: + torch.Tensor: (B, num_class, 1, chunk_size) + """ + out = [] + for i, layer in enumerate(self.ctc_lo): + out.append(layer(x)) + out = torch.cat(out, dim=1) + return out + + +def export_encoder(asr_model, args): + logger.info("Stage-1: export encoder") + decode_window, mel_dim = args.decoding_window, args.feature_size + encoder = BPUConformerEncoder(asr_model.encoder, args.chunk_size, + args.num_decoding_left_chunks, + args.ln_run_on_bpu) + encoder.eval() + encoder_outpath = os.path.join(args.output_dir, 'encoder.onnx') + + logger.info("Stage-1.1: prepare inputs for encoder") + chunk = torch.randn((1, 1, decode_window, mel_dim)) + required_cache_size = encoder.chunk_size * encoder.left_chunks + kv_time = required_cache_size + encoder.chunk_size + hidden, layers = encoder._output_size, len(encoder.encoders) + head = encoder.encoders[0].self_attn.h + d_k = hidden // head + lorder = encoder.encoders[0].conv_module.lorder + att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size) + att_mask = torch.ones((1, head, encoder.chunk_size, kv_time)) + att_mask[:, :, :, :required_cache_size] = 0 + cnn_cache = torch.zeros((1, hidden, layers, lorder)) + inputs = (chunk, att_cache, cnn_cache, att_mask) + logger.info("chunk.size(): {} att_cache.size(): {} " + "cnn_cache.size(): {} att_mask.size(): {}".format( + list(chunk.size()), list(att_cache.size()), + list(cnn_cache.size()), list(att_mask.size()))) + + logger.info("Stage-1.2: torch.onnx.export") + # NOTE(xcsong): Below attributes will be used in + # onnx2horizonbin.py::generate_config() + attributes = {} + attributes['input_name'] = "chunk;att_cache;cnn_cache;att_mask" + attributes['output_name'] = "output;r_att_cache;r_cnn_cache" + attributes['input_type'] = "featuremap;featuremap;featuremap;featuremap" + attributes['norm_type'] = \ + "no_preprocess;no_preprocess;no_preprocess;no_preprocess" + attributes['input_layout_train'] = "NCHW;NCHW;NCHW;NCHW" + attributes['input_layout_rt'] = "NCHW;NCHW;NCHW;NCHW" + attributes['input_shape'] = \ + "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format( + chunk.size(0), chunk.size(1), chunk.size(2), chunk.size(3), + att_cache.size(0), att_cache.size(1), att_cache.size(2), + att_cache.size(3), cnn_cache.size(0), cnn_cache.size(1), + cnn_cache.size(2), cnn_cache.size(3), att_mask.size(0), + att_mask.size(1), att_mask.size(2), att_mask.size(3) + ) + torch.onnx.export( # NOTE(xcsong): only support opset==11 + encoder, + inputs, + encoder_outpath, + opset_version=11, + export_params=True, + do_constant_folding=True, + input_names=attributes['input_name'].split(';'), + output_names=attributes['output_name'].split(';'), + dynamic_axes=None, + verbose=False) + onnx_encoder = onnx.load(encoder_outpath) + for k in vars(args): + meta = onnx_encoder.metadata_props.add() + meta.key, meta.value = str(k), str(getattr(args, k)) + for k in attributes: + meta = onnx_encoder.metadata_props.add() + meta.key, meta.value = str(k), str(attributes[k]) + onnx.checker.check_model(onnx_encoder) + onnx.helper.printable_graph(onnx_encoder.graph) + onnx.save(onnx_encoder, encoder_outpath) + print_input_output_info(onnx_encoder, "onnx_encoder") + logger.info('Export onnx_encoder, done! 
see {}'.format(encoder_outpath)) + + logger.info("Stage-1.3: check onnx_encoder and torch_encoder") + torch_output = [] + torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask) + torch_att_cache = copy.deepcopy(att_cache) + torch_cnn_cache = copy.deepcopy(cnn_cache) + for i in range(10): + logger.info("torch chunk-{}: {}, att_cache: {}, cnn_cache: {}" + ", att_mask: {}".format(i, list(torch_chunk.size()), + list(torch_att_cache.size()), + list(torch_cnn_cache.size()), + list(torch_att_mask.size()))) + torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1 + out, torch_att_cache, torch_cnn_cache = encoder( + torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask) + torch_output.append(out) + torch_output = torch.cat(torch_output, dim=-1) + + onnx_output = [] + onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask) + onnx_att_cache = to_numpy(att_cache) + onnx_cnn_cache = to_numpy(cnn_cache) + ort_session = onnxruntime.InferenceSession(encoder_outpath) + input_names = [node.name for node in onnx_encoder.graph.input] + for i in range(10): + logger.info("onnx chunk-{}: {}, att_cache: {}, cnn_cache: {}," + " att_mask: {}".format(i, onnx_chunk.shape, + onnx_att_cache.shape, + onnx_cnn_cache.shape, + onnx_att_mask.shape)) + onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)):] = 1 + ort_inputs = { + 'chunk': onnx_chunk, + 'att_cache': onnx_att_cache, + 'cnn_cache': onnx_cnn_cache, + 'att_mask': onnx_att_mask, + } + ort_outs = ort_session.run(None, ort_inputs) + onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] + onnx_output.append(ort_outs[0]) + onnx_output = np.concatenate(onnx_output, axis=-1) + + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output, + rtol=1e-03, + atol=1e-04) + meta = ort_session.get_modelmeta() + logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) + logger.info("Check onnx_encoder, pass!") + return encoder, ort_session + + +def export_ctc(asr_model, args): + logger.info("Stage-2: export ctc") + ctc = BPUCTC(asr_model.ctc).eval() + ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx') + + logger.info("Stage-2.1: prepare inputs for ctc") + hidden = torch.randn((1, args.output_size, 1, args.chunk_size)) + + logger.info("Stage-2.2: torch.onnx.export") + # NOTE(xcsong): Below attributes will be used in + # onnx2horizonbin.py::generate_config() + attributes = {} + attributes['input_name'], attributes['input_type'] = "hidden", "featuremap" + attributes['norm_type'] = "no_preprocess" + attributes['input_layout_train'] = "NCHW" + attributes['input_layout_rt'] = "NCHW" + attributes['input_shape'] = "{}x{}x{}x{}".format( + hidden.size(0), + hidden.size(1), + hidden.size(2), + hidden.size(3), + ) + torch.onnx.export(ctc, + hidden, + ctc_outpath, + opset_version=11, + export_params=True, + do_constant_folding=True, + input_names=['hidden'], + output_names=['probs'], + dynamic_axes=None, + verbose=False) + onnx_ctc = onnx.load(ctc_outpath) + for k in vars(args): + meta = onnx_ctc.metadata_props.add() + meta.key, meta.value = str(k), str(getattr(args, k)) + for k in attributes: + meta = onnx_ctc.metadata_props.add() + meta.key, meta.value = str(k), str(attributes[k]) + onnx.checker.check_model(onnx_ctc) + onnx.helper.printable_graph(onnx_ctc.graph) + onnx.save(onnx_ctc, ctc_outpath) + print_input_output_info(onnx_ctc, "onnx_ctc") + logger.info('Export onnx_ctc, done! 
see {}'.format(ctc_outpath)) + + logger.info("Stage-2.3: check onnx_ctc and torch_ctc") + torch_output = ctc(hidden) + ort_session = onnxruntime.InferenceSession(ctc_outpath) + onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)}) + + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output[0], + rtol=1e-03, + atol=1e-04) + meta = ort_session.get_modelmeta() + logger.info("custom_metadata_map={}".format(meta.custom_metadata_map)) + logger.info("Check onnx_ctc, pass!") + return ctc, ort_session + + +def export_decoder(asr_model, args): + logger.info("Currently, Decoder is not supported.") + + +if __name__ == '__main__': + torch.manual_seed(777) + args = get_args() + args.ln_run_on_bpu = False + # NOTE(xcsong): XJ3 BPU only support static shapes + assert args.chunk_size > 0 + assert args.num_decoding_left_chunks > 0 + os.system("mkdir -p " + args.output_dir) + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + model, configs = init_model(args, configs) + model.eval() + print(model) + + args.feature_size = configs['input_dim'] + args.output_size = model.encoder.output_size() + args.decoding_window = (args.chunk_size - 1) * \ + model.encoder.embed.subsampling_rate + \ + model.encoder.embed.right_context + 1 + + export_encoder(model, args) + export_ctc(model, args) + export_decoder(model, args) diff --git a/wenet/bin/export_onnx_cpu.py b/wenet/bin/export_onnx_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..f382545a072bb8babcb69c65913ad548d830e8a0 --- /dev/null +++ b/wenet/bin/export_onnx_cpu.py @@ -0,0 +1,470 @@ +# Copyright (c) 2022, Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
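+
+# Usage sketch (paths and values below are only illustrative; the flags
+# mirror get_args() defined in this file):
+#
+#   python wenet/bin/export_onnx_cpu.py \
+#       --config exp/train.yaml \
+#       --checkpoint exp/final.pt \
+#       --output_dir exp/onnx \
+#       --chunk_size 16 \
+#       --num_decoding_left_chunks 4 \
+#       --reverse_weight 0.5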
+ +from __future__ import print_function + +import argparse +import logging +import os +import copy +import sys + +import torch +import yaml +import numpy as np + +from wenet.utils.init_model import init_model + +try: + import onnx + import onnxruntime + from onnxruntime.quantization import quantize_dynamic, QuantType +except ImportError: + print('Please install onnx and onnxruntime!') + sys.exit(1) + + +def get_args(): + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_dir', required=True, help='output directory') + parser.add_argument('--chunk_size', + required=True, + type=int, + help='decoding chunk size') + parser.add_argument('--num_decoding_left_chunks', + required=True, + type=int, + help='cache chunks') + parser.add_argument('--reverse_weight', + default=0.5, + type=float, + help='reverse_weight in attention_rescoing') + args = parser.parse_args() + return args + + +def to_numpy(tensor): + if tensor.requires_grad: + return tensor.detach().cpu().numpy() + else: + return tensor.cpu().numpy() + + +def print_input_output_info(onnx_model, name, prefix="\t\t"): + input_names = [node.name for node in onnx_model.graph.input] + input_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.input] + output_names = [node.name for node in onnx_model.graph.output] + output_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.output] + print("{}{} inputs : {}".format(prefix, name, input_names)) + print("{}{} input shapes : {}".format(prefix, name, input_shapes)) + print("{}{} outputs: {}".format(prefix, name, output_names)) + print("{}{} output shapes : {}".format(prefix, name, output_shapes)) + + +def export_encoder(asr_model, args): + print("Stage-1: export encoder") + encoder = asr_model.encoder + encoder.forward = encoder.forward_chunk + encoder_outpath = os.path.join(args['output_dir'], 'encoder.onnx') + + print("\tStage-1.1: prepare inputs for encoder") + chunk = torch.randn( + (args['batch'], args['decoding_window'], args['feature_size'])) + offset = 0 + # NOTE(xcsong): The uncertainty of `next_cache_start` only appears + # in the first few chunks, this is caused by dynamic att_cache shape, i,e + # (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent + # chunks. One way to ease the ONNX export is to keep `next_cache_start` + # as a fixed value. To do this, for the **first** chunk, if + # left_chunks > 0, we feed real cache & real mask to the model, otherwise + # fake cache & fake mask. In this way, we get: + # 1. 16/-1 mode: next_cache_start == 0 for all chunks + # 2. 16/4 mode: next_cache_start == chunk_size for all chunks + # 3. 16/0 mode: next_cache_start == chunk_size for all chunks + # 4. -1/-1 mode: next_cache_start == 0 for all chunks + # NO MORE DYNAMIC CHANGES!! + # + # NOTE(Mddct): We retain the current design for the convenience of supporting some + # inference frameworks without dynamic shapes. 
If you're interested in all-in-one + # model that supports different chunks please see: + # https://github.com/wenet-e2e/wenet/pull/1174 + + if args['left_chunks'] > 0: # 16/4 + required_cache_size = args['chunk_size'] * args['left_chunks'] + offset = required_cache_size + # Real cache + att_cache = torch.zeros( + (args['num_blocks'], args['head'], required_cache_size, + args['output_size'] // args['head'] * 2)) + # Real mask + att_mask = torch.ones( + (args['batch'], 1, required_cache_size + args['chunk_size']), + dtype=torch.bool) + att_mask[:, :, :required_cache_size] = 0 + elif args['left_chunks'] <= 0: # 16/-1, -1/-1, 16/0 + required_cache_size = -1 if args['left_chunks'] < 0 else 0 + # Fake cache + att_cache = torch.zeros((args['num_blocks'], args['head'], 0, + args['output_size'] // args['head'] * 2)) + # Fake mask + att_mask = torch.ones((0, 0, 0), dtype=torch.bool) + cnn_cache = torch.zeros( + (args['num_blocks'], args['batch'], args['output_size'], + args['cnn_module_kernel'] - 1)) + inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache, + att_mask) + print("\t\tchunk.size(): {}\n".format(chunk.size()), + "\t\toffset: {}\n".format(offset), + "\t\trequired_cache: {}\n".format(required_cache_size), + "\t\tatt_cache.size(): {}\n".format(att_cache.size()), + "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()), + "\t\tatt_mask.size(): {}\n".format(att_mask.size())) + + print("\tStage-1.2: torch.onnx.export") + dynamic_axes = { + 'chunk': { + 1: 'T' + }, + 'att_cache': { + 2: 'T_CACHE' + }, + 'att_mask': { + 2: 'T_ADD_T_CACHE' + }, + 'output': { + 1: 'T' + }, + 'r_att_cache': { + 2: 'T_CACHE' + }, + } + # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is + # to avoid padding the last chunk (which usually contains less + # frames than required). For users who want static axes, just pop + # out specific axis. + # if args['chunk_size'] > 0: # 16/4, 16/-1, 16/0 + # dynamic_axes.pop('chunk') + # dynamic_axes.pop('output') + # if args['left_chunks'] >= 0: # 16/4, 16/0 + # # NOTE(xsong): since we feed real cache & real mask into the + # # model when left_chunks > 0, the shape of cache will never + # # be changed. + # dynamic_axes.pop('att_cache') + # dynamic_axes.pop('r_att_cache') + torch.onnx.export(encoder, + inputs, + encoder_outpath, + opset_version=13, + export_params=True, + do_constant_folding=True, + input_names=[ + 'chunk', 'offset', 'required_cache_size', + 'att_cache', 'cnn_cache', 'att_mask' + ], + output_names=['output', 'r_att_cache', 'r_cnn_cache'], + dynamic_axes=dynamic_axes, + verbose=False) + onnx_encoder = onnx.load(encoder_outpath) + for (k, v) in args.items(): + meta = onnx_encoder.metadata_props.add() + meta.key, meta.value = str(k), str(v) + onnx.checker.check_model(onnx_encoder) + onnx.helper.printable_graph(onnx_encoder.graph) + # NOTE(xcsong): to add those metadatas we need to reopen + # the file and resave it. + onnx.save(onnx_encoder, encoder_outpath) + print_input_output_info(onnx_encoder, "onnx_encoder") + # Dynamic quantization + model_fp32 = encoder_outpath + model_quant = os.path.join(args['output_dir'], 'encoder.quant.onnx') + quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8) + print('\t\tExport onnx_encoder, done! 
see {}'.format(encoder_outpath)) + + print("\tStage-1.3: check onnx_encoder and torch_encoder") + torch_output = [] + torch_chunk = copy.deepcopy(chunk) + torch_offset = copy.deepcopy(offset) + torch_required_cache_size = copy.deepcopy(required_cache_size) + torch_att_cache = copy.deepcopy(att_cache) + torch_cnn_cache = copy.deepcopy(cnn_cache) + torch_att_mask = copy.deepcopy(att_mask) + for i in range(10): + print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {}," + " cnn_cache: {}, att_mask: {}".format( + i, list(torch_chunk.size()), torch_offset, + list(torch_att_cache.size()), list(torch_cnn_cache.size()), + list(torch_att_mask.size()))) + # NOTE(xsong): att_mask of the first few batches need changes if + # we use 16/4 mode. + if args['left_chunks'] > 0: # 16/4 + torch_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1 + out, torch_att_cache, torch_cnn_cache = encoder( + torch_chunk, torch_offset, torch_required_cache_size, + torch_att_cache, torch_cnn_cache, torch_att_mask) + torch_output.append(out) + torch_offset += out.size(1) + torch_output = torch.cat(torch_output, dim=1) + + onnx_output = [] + onnx_chunk = to_numpy(chunk) + onnx_offset = np.array((offset)).astype(np.int64) + onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64) + onnx_att_cache = to_numpy(att_cache) + onnx_cnn_cache = to_numpy(cnn_cache) + onnx_att_mask = to_numpy(att_mask) + ort_session = onnxruntime.InferenceSession( + encoder_outpath, providers=['CPUExecutionProvider']) + input_names = [node.name for node in onnx_encoder.graph.input] + for i in range(10): + print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {}," + " cnn_cache: {}, att_mask: {}".format(i, onnx_chunk.shape, + onnx_offset, + onnx_att_cache.shape, + onnx_cnn_cache.shape, + onnx_att_mask.shape)) + # NOTE(xsong): att_mask of the first few batches need changes if + # we use 16/4 mode. + if args['left_chunks'] > 0: # 16/4 + onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1 + ort_inputs = { + 'chunk': onnx_chunk, + 'offset': onnx_offset, + 'required_cache_size': onnx_required_cache_size, + 'att_cache': onnx_att_cache, + 'cnn_cache': onnx_cnn_cache, + 'att_mask': onnx_att_mask + } + # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start` + # will be hardcoded to 0 or chunk_size by ONNX, thus + # required_cache_size and att_mask are no more needed and they will + # be removed by ONNX automatically. 
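+        # Keep only feeds that are still inputs of the exported graph;
+        # onnxruntime rejects a run() call whose feed dict contains a name
+        # the model does not declare.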
+ for k in list(ort_inputs): + if k not in input_names: + ort_inputs.pop(k) + ort_outs = ort_session.run(None, ort_inputs) + onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] + onnx_output.append(ort_outs[0]) + onnx_offset += ort_outs[0].shape[1] + onnx_output = np.concatenate(onnx_output, axis=1) + + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output, + rtol=1e-03, + atol=1e-05) + meta = ort_session.get_modelmeta() + print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map)) + print("\t\tCheck onnx_encoder, pass!") + + +def export_ctc(asr_model, args): + print("Stage-2: export ctc") + ctc = asr_model.ctc + ctc.forward = ctc.log_softmax + ctc_outpath = os.path.join(args['output_dir'], 'ctc.onnx') + + print("\tStage-2.1: prepare inputs for ctc") + hidden = torch.randn( + (args['batch'], args['chunk_size'] if args['chunk_size'] > 0 else 16, + args['output_size'])) + + print("\tStage-2.2: torch.onnx.export") + dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}} + torch.onnx.export(ctc, + hidden, + ctc_outpath, + opset_version=13, + export_params=True, + do_constant_folding=True, + input_names=['hidden'], + output_names=['probs'], + dynamic_axes=dynamic_axes, + verbose=False) + onnx_ctc = onnx.load(ctc_outpath) + for (k, v) in args.items(): + meta = onnx_ctc.metadata_props.add() + meta.key, meta.value = str(k), str(v) + onnx.checker.check_model(onnx_ctc) + onnx.helper.printable_graph(onnx_ctc.graph) + onnx.save(onnx_ctc, ctc_outpath) + print_input_output_info(onnx_ctc, "onnx_ctc") + # Dynamic quantization + model_fp32 = ctc_outpath + model_quant = os.path.join(args['output_dir'], 'ctc.quant.onnx') + quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8) + print('\t\tExport onnx_ctc, done! see {}'.format(ctc_outpath)) + + print("\tStage-2.3: check onnx_ctc and torch_ctc") + torch_output = ctc(hidden) + ort_session = onnxruntime.InferenceSession( + ctc_outpath, providers=['CPUExecutionProvider']) + onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)}) + + np.testing.assert_allclose(to_numpy(torch_output), + onnx_output[0], + rtol=1e-03, + atol=1e-05) + print("\t\tCheck onnx_ctc, pass!") + + +def export_decoder(asr_model, args): + print("Stage-3: export decoder") + decoder = asr_model + # NOTE(lzhin): parameters of encoder will be automatically removed + # since they are not used during rescoring. + decoder.forward = decoder.forward_attention_decoder + decoder_outpath = os.path.join(args['output_dir'], 'decoder.onnx') + + print("\tStage-3.1: prepare inputs for decoder") + # hardcode time->200 nbest->10 len->20, they are dynamic axes. 
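+    # The dummy tensors below only fix example shapes for tracing; the
+    # dynamic_axes mapping in Stage-3.2 keeps T (encoder frames), NBEST and
+    # L (hypothesis length) flexible at inference time.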
+ encoder_out = torch.randn((1, 200, args['output_size'])) + hyps = torch.randint(low=0, high=args['vocab_size'], size=[10, 20]) + hyps[:, 0] = args['vocab_size'] - 1 # + hyps_lens = torch.randint(low=15, high=21, size=[10]) + + print("\tStage-3.2: torch.onnx.export") + dynamic_axes = { + 'hyps': { + 0: 'NBEST', + 1: 'L' + }, + 'hyps_lens': { + 0: 'NBEST' + }, + 'encoder_out': { + 1: 'T' + }, + 'score': { + 0: 'NBEST', + 1: 'L' + }, + 'r_score': { + 0: 'NBEST', + 1: 'L' + } + } + inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight']) + torch.onnx.export( + decoder, + inputs, + decoder_outpath, + opset_version=13, + export_params=True, + do_constant_folding=True, + input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'], + output_names=['score', 'r_score'], + dynamic_axes=dynamic_axes, + verbose=False) + onnx_decoder = onnx.load(decoder_outpath) + for (k, v) in args.items(): + meta = onnx_decoder.metadata_props.add() + meta.key, meta.value = str(k), str(v) + onnx.checker.check_model(onnx_decoder) + onnx.helper.printable_graph(onnx_decoder.graph) + onnx.save(onnx_decoder, decoder_outpath) + print_input_output_info(onnx_decoder, "onnx_decoder") + model_fp32 = decoder_outpath + model_quant = os.path.join(args['output_dir'], 'decoder.quant.onnx') + quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8) + print('\t\tExport onnx_decoder, done! see {}'.format(decoder_outpath)) + + print("\tStage-3.3: check onnx_decoder and torch_decoder") + torch_score, torch_r_score = decoder(hyps, hyps_lens, encoder_out, + args['reverse_weight']) + ort_session = onnxruntime.InferenceSession( + decoder_outpath, providers=['CPUExecutionProvider']) + input_names = [node.name for node in onnx_decoder.graph.input] + ort_inputs = { + 'hyps': to_numpy(hyps), + 'hyps_lens': to_numpy(hyps_lens), + 'encoder_out': to_numpy(encoder_out), + 'reverse_weight': np.array((args['reverse_weight'])), + } + for k in list(ort_inputs): + if k not in input_names: + ort_inputs.pop(k) + onnx_output = ort_session.run(None, ort_inputs) + + np.testing.assert_allclose(to_numpy(torch_score), + onnx_output[0], + rtol=1e-03, + atol=1e-05) + if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0: + np.testing.assert_allclose(to_numpy(torch_r_score), + onnx_output[1], + rtol=1e-03, + atol=1e-05) + print("\t\tCheck onnx_decoder, pass!") + + +def main(): + torch.manual_seed(777) + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + output_dir = args.output_dir + os.system("mkdir -p " + output_dir) + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + model, configs = init_model(args, configs) + model.eval() + print(model) + + arguments = {} + arguments['output_dir'] = output_dir + arguments['batch'] = 1 + arguments['chunk_size'] = args.chunk_size + arguments['left_chunks'] = args.num_decoding_left_chunks + arguments['reverse_weight'] = args.reverse_weight + arguments['output_size'] = configs['encoder_conf']['output_size'] + arguments['num_blocks'] = configs['encoder_conf']['num_blocks'] + arguments['cnn_module_kernel'] = configs['encoder_conf'].get( + 'cnn_module_kernel', 1) + arguments['head'] = configs['encoder_conf']['attention_heads'] + arguments['feature_size'] = configs['input_dim'] + arguments['vocab_size'] = configs['output_dim'] + # NOTE(xcsong): if chunk_size == -1, hardcode to 67 + arguments['decoding_window'] = (args.chunk_size - 1) * \ + 
model.encoder.embed.subsampling_rate + \ + model.encoder.embed.right_context + 1 if args.chunk_size > 0 else 67 + arguments['encoder'] = configs['encoder'] + arguments['decoder'] = configs['decoder'] + arguments['subsampling_rate'] = model.subsampling_rate() + arguments['right_context'] = model.right_context() + arguments['sos_symbol'] = model.sos_symbol() + arguments['eos_symbol'] = model.eos_symbol() + arguments['is_bidirectional_decoder'] = 1 \ + if model.is_bidirectional_decoder() else 0 + + # NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is + # not a [16/4 16/-1 16/0] all-in-one model and it should not be used in + # streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you + # want to use 16/-1 or any other streaming mode in `decoder_main`, + # please export onnx in the same config. + if arguments['left_chunks'] > 0: + assert arguments['chunk_size'] > 0 # -1/4 not supported + + export_encoder(model, arguments) + export_ctc(model, arguments) + export_decoder(model, arguments) + + +if __name__ == '__main__': + main() diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6c1dbe0584889c96f1c4556d989149e34b2774 --- /dev/null +++ b/wenet/bin/export_onnx_gpu.py @@ -0,0 +1,1263 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import os +import sys + +import torch +import yaml +import logging + +import torch.nn.functional as F +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import TransformerDecoder +from wenet.transformer.encoder import BaseEncoder +from wenet.utils.init_model import init_model +from wenet.utils.mask import make_pad_mask + +try: + import onnxruntime +except ImportError: + print("Please install onnxruntime-gpu!") + sys.exit(1) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +class Encoder(torch.nn.Module): + + def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10): + super().__init__() + self.encoder = encoder + self.ctc = ctc + self.beam_size = beam_size + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ): + """Encoder + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + Returns: + encoder_out: B x T x F + encoder_out_lens: B + ctc_log_probs: B x T x V + beam_log_probs: B x T x beam_size + beam_log_probs_idx: B x T x beam_size + """ + encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, + -1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_log_probs = self.ctc.log_softmax(encoder_out) + encoder_out_lens = encoder_out_lens.int() + beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs, + self.beam_size, + dim=2) + return ( + encoder_out, + encoder_out_lens, + ctc_log_probs, + beam_log_probs, + beam_log_probs_idx, + ) + + +class StreamingEncoder(torch.nn.Module): + + def __init__( + self, + model, + required_cache_size, + beam_size, + transformer=False, + return_ctc_logprobs=False, + ): + super().__init__() + self.ctc = model.ctc + self.subsampling_rate = model.encoder.embed.subsampling_rate + self.embed = model.encoder.embed + self.global_cmvn = model.encoder.global_cmvn + self.required_cache_size = required_cache_size + self.beam_size = beam_size + self.encoder = model.encoder + self.transformer = transformer + self.return_ctc_logprobs = return_ctc_logprobs + + def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, + cache_mask): + """Streaming Encoder + Args: + xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (torch.Tensor): offset with shape (b, 1) + 1 is retained for triton deployment + required_cache_size (int): cache size required for next chunk + compuation + > 0: actual cache size + <= 0: not allowed in streaming gpu encoder ` + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (b, elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (b, elayers, b, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) + in a batch of request, each request may have different + history cache. Cache mask is used to indidate the effective + cache for each request + Returns: + torch.Tensor: log probabilities of ctc output and cutoff by beam size + with shape (b, chunk_size, beam) + torch.Tensor: index of top beam size probabilities for each timestep + with shape (b, chunk_size, beam) + torch.Tensor: output of current input xs, + with shape (b, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + same shape (b, elayers, head, cache_t1, d_k * 2) + as the original att_cache + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
+ torch.Tensor: new cache mask, with same shape as the original + cache mask + """ + offset = offset.squeeze(1) + T = chunk_xs.size(1) + chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) + # B X 1 X T + chunk_mask = chunk_mask.to(chunk_xs.dtype) + # transpose batch & num_layers dim + att_cache = torch.transpose(att_cache, 0, 1) + cnn_cache = torch.transpose(cnn_cache, 0, 1) + + # rewrite encoder.forward_chunk + # <---------forward_chunk START---------> + xs = self.global_cmvn(chunk_xs) + # chunk mask is important for batch inferencing since + # different sequence in a batch has different length + xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) + cache_size = att_cache.size(3) # required cache size + masks = torch.cat((cache_mask, chunk_mask), dim=2) + index = offset - cache_size + + pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) + pos_emb = pos_emb.to(dtype=xs.dtype) + + next_cache_start = -self.required_cache_size + r_cache_mask = masks[:, :, next_cache_start:] + + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoder.encoders): + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + masks, + pos_emb, + att_cache=att_cache[i], + cnn_cache=cnn_cache[i], + ) + # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (B, hidden-dim, cache_t2) + r_att_cache.append( + new_att_cache[:, :, next_cache_start:, :].unsqueeze(1)) + if not self.transformer: + r_cnn_cache.append(new_cnn_cache.unsqueeze(1)) + if self.encoder.normalize_before: + chunk_out = self.encoder.after_norm(xs) + else: + chunk_out = xs + + r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx + if not self.transformer: + r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers + + # <---------forward_chunk END---------> + + log_ctc_probs = self.ctc.log_softmax(chunk_out) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) + log_probs = log_probs.to(chunk_xs.dtype) + + r_offset = offset + chunk_out.shape[1] + # the below ops not supported in Tensorrt + # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, + # rounding_mode='floor') + chunk_out_lens = chunk_lens // self.subsampling_rate + r_offset = r_offset.unsqueeze(1) + if self.return_ctc_logprobs: + return ( + log_ctc_probs, + chunk_out, + chunk_out_lens, + r_offset, + r_att_cache, + r_cnn_cache, + r_cache_mask, + ) + else: + return ( + log_probs, + log_probs_idx, + chunk_out, + chunk_out_lens, + r_offset, + r_att_cache, + r_cnn_cache, + r_cache_mask, + ) + + +class StreamingSqueezeformerEncoder(torch.nn.Module): + + def __init__(self, model, required_cache_size, beam_size): + super().__init__() + self.ctc = model.ctc + self.subsampling_rate = model.encoder.embed.subsampling_rate + self.embed = model.encoder.embed + self.global_cmvn = model.encoder.global_cmvn + self.required_cache_size = required_cache_size + self.beam_size = beam_size + self.encoder = model.encoder + self.reduce_idx = model.encoder.reduce_idx + self.recover_idx = model.encoder.recover_idx + if self.reduce_idx is None: + self.time_reduce = None + else: + if self.recover_idx is None: + self.time_reduce = "normal" # no recovery at the end + else: + self.time_reduce = "recover" # recovery at the end + assert len(self.reduce_idx) == len(self.recover_idx) + + def calculate_downsampling_factor(self, i: int) -> int: + if self.reduce_idx is None: + return 1 + else: + reduce_exp, recover_exp = 0, 0 + for exp, rd_idx in enumerate(self.reduce_idx): + if i >= rd_idx: + 
reduce_exp = exp + 1 + if self.recover_idx is not None: + for exp, rc_idx in enumerate(self.recover_idx): + if i >= rc_idx: + recover_exp = exp + 1 + return int(2**(reduce_exp - recover_exp)) + + def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, + cache_mask): + """Streaming Encoder + Args: + xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (torch.Tensor): offset with shape (b, 1) + 1 is retained for triton deployment + required_cache_size (int): cache size required for next chunk + compuation + > 0: actual cache size + <= 0: not allowed in streaming gpu encoder ` + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (b, elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (b, elayers, b, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) + in a batch of request, each request may have different + history cache. Cache mask is used to indidate the effective + cache for each request + Returns: + torch.Tensor: log probabilities of ctc output and cutoff by beam size + with shape (b, chunk_size, beam) + torch.Tensor: index of top beam size probabilities for each timestep + with shape (b, chunk_size, beam) + torch.Tensor: output of current input xs, + with shape (b, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + same shape (b, elayers, head, cache_t1, d_k * 2) + as the original att_cache + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
+ torch.Tensor: new cache mask, with same shape as the original + cache mask + """ + offset = offset.squeeze(1) + T = chunk_xs.size(1) + chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) + # B X 1 X T + chunk_mask = chunk_mask.to(chunk_xs.dtype) + # transpose batch & num_layers dim + att_cache = torch.transpose(att_cache, 0, 1) + cnn_cache = torch.transpose(cnn_cache, 0, 1) + + # rewrite encoder.forward_chunk + # <---------forward_chunk START---------> + xs = self.global_cmvn(chunk_xs) + # chunk mask is important for batch inferencing since + # different sequence in a batch has different length + xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) + elayers, cache_size = att_cache.size(0), att_cache.size(3) + att_mask = torch.cat((cache_mask, chunk_mask), dim=2) + index = offset - cache_size + + pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) + pos_emb = pos_emb.to(dtype=xs.dtype) + + next_cache_start = -self.required_cache_size + r_cache_mask = att_mask[:, :, next_cache_start:] + + r_att_cache = [] + r_cnn_cache = [] + mask_pad = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + mask_pad = mask_pad.unsqueeze(1) + max_att_len: int = 0 + recover_activations: List[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]] = [] + index = 0 + xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int) + xs = self.encoder.preln(xs) + for i, layer in enumerate(self.encoder.encoders): + if self.reduce_idx is not None: + if self.time_reduce is not None and i in self.reduce_idx: + recover_activations.append( + (xs, att_mask, pos_emb, mask_pad)) + ( + xs, + xs_lens, + att_mask, + mask_pad, + ) = self.encoder.time_reduction_layer( + xs, xs_lens, att_mask, mask_pad) + pos_emb = pos_emb[:, ::2, :] + if self.encoder.pos_enc_layer_type == "rel_pos_repaired": + pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :] + index += 1 + + if self.recover_idx is not None: + if self.time_reduce == "recover" and i in self.recover_idx: + index -= 1 + ( + recover_tensor, + recover_att_mask, + recover_pos_emb, + recover_mask_pad, + ) = recover_activations[index] + # recover output length for ctc decode + xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2) + xs = self.encoder.time_recover_layer(xs) + recoverd_t = recover_tensor.size(1) + xs = recover_tensor + xs[:, :recoverd_t, :].contiguous() + att_mask = recover_att_mask + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + factor = self.calculate_downsampling_factor(i) + + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + att_cache=att_cache[i][:, :, ::factor, :] + [:, :, :pos_emb.size(1) - xs.size(1), :] + if elayers > 0 else att_cache[:, :, ::factor, :], + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache, + ) + cached_att = new_att_cache[:, :, next_cache_start // factor:, :] + cached_cnn = new_cnn_cache.unsqueeze(1) + cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor, + 1).flatten(2, 3)) + if i == 0: + # record length for the first block as max length + max_att_len = cached_att.size(2) + r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1)) + r_cnn_cache.append(cached_cnn) + + chunk_out = xs + r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx + r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers + + # <---------forward_chunk END---------> + + log_ctc_probs = self.ctc.log_softmax(chunk_out) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) + log_probs = 
log_probs.to(chunk_xs.dtype) + + r_offset = offset + chunk_out.shape[1] + # the below ops not supported in Tensorrt + # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, + # rounding_mode='floor') + chunk_out_lens = chunk_lens // self.subsampling_rate + r_offset = r_offset.unsqueeze(1) + + return ( + log_probs, + log_probs_idx, + chunk_out, + chunk_out_lens, + r_offset, + r_att_cache, + r_cnn_cache, + r_cache_mask, + ) + + +class StreamingEfficientConformerEncoder(torch.nn.Module): + + def __init__(self, model, required_cache_size, beam_size): + super().__init__() + self.ctc = model.ctc + self.subsampling_rate = model.encoder.embed.subsampling_rate + self.embed = model.encoder.embed + self.global_cmvn = model.encoder.global_cmvn + self.required_cache_size = required_cache_size + self.beam_size = beam_size + self.encoder = model.encoder + + # Efficient Conformer + self.stride_layer_idx = model.encoder.stride_layer_idx + self.stride = model.encoder.stride + self.num_blocks = model.encoder.num_blocks + self.cnn_module_kernel = model.encoder.cnn_module_kernel + + def calculate_downsampling_factor(self, i: int) -> int: + factor = 1 + for idx, stride_idx in enumerate(self.stride_layer_idx): + if i > stride_idx: + factor *= self.stride[idx] + return factor + + def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, + cache_mask): + """Streaming Encoder + Args: + chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + chunk_lens (torch.Tensor): + offset (torch.Tensor): offset with shape (b, 1) + 1 is retained for triton deployment + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (b, elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (b, elayers, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) + in a batch of request, each request may have different + history cache. Cache mask is used to indidate the effective + cache for each request + Returns: + torch.Tensor: log probabilities of ctc output and cutoff by beam size + with shape (b, chunk_size, beam) + torch.Tensor: index of top beam size probabilities for each timestep + with shape (b, chunk_size, beam) + torch.Tensor: output of current input xs, + with shape (b, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + same shape (b, elayers, head, cache_t1, d_k * 2) + as the original att_cache + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
+ torch.Tensor: new cache mask, with same shape as the original + cache mask + """ + offset = offset.squeeze(1) # (b, ) + offset *= self.calculate_downsampling_factor(self.num_blocks + 1) + + T = chunk_xs.size(1) + chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) # (b, 1, T) + # B X 1 X T + chunk_mask = chunk_mask.to(chunk_xs.dtype) + # transpose batch & num_layers dim + # Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2) + # Shape(cnn_cache): (elayers, b, outsize, cnn_kernel) + att_cache = torch.transpose(att_cache, 0, 1) + cnn_cache = torch.transpose(cnn_cache, 0, 1) + + # rewrite encoder.forward_chunk + # <---------forward_chunk START---------> + xs = self.global_cmvn(chunk_xs) + # chunk mask is important for batch inferencing since + # different sequence in a batch has different length + xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) + cache_size = att_cache.size(3) # required cache size + masks = torch.cat((cache_mask, chunk_mask), dim=2) + att_mask = torch.cat((cache_mask, chunk_mask), dim=2) + index = offset - cache_size + + pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) + pos_emb = pos_emb.to(dtype=xs.dtype) + + next_cache_start = -self.required_cache_size + r_cache_mask = masks[:, :, next_cache_start:] + + r_att_cache = [] + r_cnn_cache = [] + mask_pad = chunk_mask.to(torch.bool) + max_att_len, max_cnn_len = ( + 0, + 0, + ) # for repeat_interleave of new_att_cache + for i, layer in enumerate(self.encoder.encoders): + factor = self.calculate_downsampling_factor(i) + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ] + att_cache_trunc = 0 + if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1): + # The time step is not divisible by the downsampling multiple + # We propose to double the chunk_size. + att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor - + pos_emb.size(1) + 1) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, + att_mask, + pos_emb, + mask_pad=mask_pad, + att_cache=att_cache[i][:, :, ::factor, :][:, :, + att_cache_trunc:, :], + cnn_cache=cnn_cache[i, :, :, :] + if cnn_cache.size(0) > 0 else cnn_cache, + ) + + if i in self.stride_layer_idx: + # compute time dimension for next block + efficient_index = self.stride_layer_idx.index(i) + att_mask = att_mask[:, ::self.stride[efficient_index], ::self. + stride[efficient_index], ] + mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self. 
+ stride[efficient_index], ] + pos_emb = pos_emb[:, ::self.stride[efficient_index], :] + + # shape(new_att_cache) = [batch, head, time2, outdim] + new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :] + # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2] + new_cnn_cache = new_cnn_cache.unsqueeze(1) # shape(1):layerID + + # use repeat_interleave to new_att_cache + # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2) + new_att_cache = (new_att_cache.unsqueeze(3).repeat( + 1, 1, 1, factor, 1).flatten(2, 3)) + # padding new_cnn_cache to cnn.lorder for casual convolution + new_cnn_cache = F.pad( + new_cnn_cache, + (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0), + ) + + if i == 0: + # record length for the first block as max length + max_att_len = new_att_cache.size(2) + max_cnn_len = new_cnn_cache.size(3) + + # update real shape of att_cache and cnn_cache + r_att_cache.append(new_att_cache[:, :, + -max_att_len:, :].unsqueeze(1)) + r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:]) + + if self.encoder.normalize_before: + chunk_out = self.encoder.after_norm(xs) + else: + chunk_out = xs + + # shape of r_att_cache: (b, elayers, head, time2, outdim) + r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx + # shape of r_cnn_cache: (b, elayers, outdim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers + + # <---------forward_chunk END---------> + + log_ctc_probs = self.ctc.log_softmax(chunk_out) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) + log_probs = log_probs.to(chunk_xs.dtype) + + r_offset = offset + chunk_out.shape[1] + # the below ops not supported in Tensorrt + # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, + # rounding_mode='floor') + chunk_out_lens = ( + chunk_lens // self.subsampling_rate // + self.calculate_downsampling_factor(self.num_blocks + 1)) + chunk_out_lens += 1 + r_offset = r_offset.unsqueeze(1) + + return ( + log_probs, + log_probs_idx, + chunk_out, + chunk_out_lens, + r_offset, + r_att_cache, + r_cnn_cache, + r_cache_mask, + ) + + +class Decoder(torch.nn.Module): + + def __init__( + self, + decoder: TransformerDecoder, + ctc_weight: float = 0.5, + reverse_weight: float = 0.0, + beam_size: int = 10, + decoder_fastertransformer: bool = False, + ): + super().__init__() + self.decoder = decoder + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + self.beam_size = beam_size + self.decoder_fastertransformer = decoder_fastertransformer + + def forward( + self, + encoder_out: torch.Tensor, + encoder_lens: torch.Tensor, + hyps_pad_sos_eos: torch.Tensor, + hyps_lens_sos: torch.Tensor, + r_hyps_pad_sos_eos: torch.Tensor, + ctc_score: torch.Tensor, + ): + """Encoder + Args: + encoder_out: B x T x F + encoder_lens: B + hyps_pad_sos_eos: B x beam x (T2+1), + hyps with sos & eos and padded by ignore id + hyps_lens_sos: B x beam, length for each hyp with sos + r_hyps_pad_sos_eos: B x beam x (T2+1), + reversed hyps with sos & eos and padded by ignore id + ctc_score: B x beam, ctc score for each hyp + Returns: + decoder_out: B x beam x T2 x V + r_decoder_out: B x beam x T2 x V + best_index: B + """ + B, T, F = encoder_out.shape + bz = self.beam_size + B2 = B * bz + encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) + encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1) + encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) + T2 = hyps_pad_sos_eos.shape[2] - 1 + hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) + hyps_lens = 
hyps_lens_sos.view(B2, ) + hyps_pad_sos = hyps_pad[:, :-1].contiguous() + hyps_pad_eos = hyps_pad[:, 1:].contiguous() + + r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1) + r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous() + r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous() + + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, + encoder_mask, + hyps_pad_sos, + hyps_lens, + r_hyps_pad_sos, + self.reverse_weight, + ) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + V = decoder_out.shape[-1] + decoder_out = decoder_out.view(B2, T2, V) + mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2 + # mask index, remove ignore id + index = torch.unsqueeze(hyps_pad_eos * mask, 2) + score = decoder_out.gather(2, index).squeeze(2) # B2 X T2 + # mask padded part + score = score * mask + decoder_out = decoder_out.view(B, bz, T2, V) + if self.reverse_weight > 0: + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, + dim=-1) + r_decoder_out = r_decoder_out.view(B2, T2, V) + index = torch.unsqueeze(r_hyps_pad_eos * mask, 2) + r_score = r_decoder_out.gather(2, index).squeeze(2) + r_score = r_score * mask + score = (score * (1 - self.reverse_weight) + + self.reverse_weight * r_score) + r_decoder_out = r_decoder_out.view(B, bz, T2, V) + score = torch.sum(score, axis=1) # B2 + score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score + best_index = torch.argmax(score, dim=1) + if self.decoder_fastertransformer: + return decoder_out, best_index + else: + return best_index + + +def to_numpy(tensors): + out = [] + if type(tensors) == torch.tensor: + tensors = [tensors] + for tensor in tensors: + if tensor.requires_grad: + tensor = tensor.detach().cpu().numpy() + else: + tensor = tensor.cpu().numpy() + out.append(tensor) + return out + + +def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): + for a, b in zip(xlist, blist): + try: + torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + except AssertionError as error: + if tolerate_small_mismatch: + print(error) + else: + raise + + +def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): + bz = 32 + seq_len = 100 + beam_size = args.beam_size + feature_size = configs["input_dim"] + + speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32) + speech_lens = torch.randint(low=10, + high=seq_len, + size=(bz, ), + dtype=torch.int32) + encoder = Encoder(model.encoder, model.ctc, beam_size) + encoder.eval() + + torch.onnx.export( + encoder, + (speech, speech_lens), + encoder_onnx_path, + export_params=True, + opset_version=13, + do_constant_folding=True, + input_names=["speech", "speech_lengths"], + output_names=[ + "encoder_out", + "encoder_out_lens", + "ctc_log_probs", + "beam_log_probs", + "beam_log_probs_idx", + ], + dynamic_axes={ + "speech": { + 0: "B", + 1: "T" + }, + "speech_lengths": { + 0: "B" + }, + "encoder_out": { + 0: "B", + 1: "T_OUT" + }, + "encoder_out_lens": { + 0: "B" + }, + "ctc_log_probs": { + 0: "B", + 1: "T_OUT" + }, + "beam_log_probs": { + 0: "B", + 1: "T_OUT" + }, + "beam_log_probs_idx": { + 0: "B", + 1: "T_OUT" + }, + }, + verbose=False, + ) + + with torch.no_grad(): + o0, o1, o2, o3, o4 = encoder(speech, speech_lens) + + providers = ["CUDAExecutionProvider"] + ort_session = onnxruntime.InferenceSession(encoder_onnx_path, + providers=providers) + ort_inputs = { + "speech": to_numpy(speech), + "speech_lengths": to_numpy(speech_lens), + } + ort_outs = ort_session.run(None, ort_inputs) + + # check encoder output + test(to_numpy([o0, o1, o2, o3, 
o4]), ort_outs) + logger.info("export offline onnx encoder succeed!") + onnx_config = { + "beam_size": args.beam_size, + "reverse_weight": args.reverse_weight, + "ctc_weight": args.ctc_weight, + "fp16": args.fp16, + } + return onnx_config + + +def export_online_encoder(model, configs, args, logger, encoder_onnx_path): + decoding_chunk_size = args.decoding_chunk_size + subsampling = model.encoder.embed.subsampling_rate + context = model.encoder.embed.right_context + 1 + decoding_window = (decoding_chunk_size - 1) * subsampling + context + batch_size = 32 + audio_len = decoding_window + feature_size = configs["input_dim"] + output_size = configs["encoder_conf"]["output_size"] + num_layers = configs["encoder_conf"]["num_blocks"] + # in transformer the cnn module will not be available + transformer = False + cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1 + if not cnn_module_kernel: + transformer = True + num_decoding_left_chunks = args.num_decoding_left_chunks + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + if configs["encoder"] == "squeezeformer": + encoder = StreamingSqueezeformerEncoder(model, required_cache_size, + args.beam_size) + elif configs["encoder"] == "efficientConformer": + encoder = StreamingEfficientConformerEncoder(model, + required_cache_size, + args.beam_size) + else: + encoder = StreamingEncoder( + model, + required_cache_size, + args.beam_size, + transformer, + args.return_ctc_logprobs, + ) + encoder.eval() + + # begin to export encoder + chunk_xs = torch.randn(batch_size, + audio_len, + feature_size, + dtype=torch.float32) + chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len + + offset = torch.arange(0, batch_size).unsqueeze(1) + # (elayers, b, head, cache_t1, d_k * 2) + head = configs["encoder_conf"]["attention_heads"] + d_k = configs["encoder_conf"]["output_size"] // head + att_cache = torch.randn( + batch_size, + num_layers, + head, + required_cache_size, + d_k * 2, + dtype=torch.float32, + ) + cnn_cache = torch.randn( + batch_size, + num_layers, + output_size, + cnn_module_kernel, + dtype=torch.float32, + ) + + cache_mask = torch.ones(batch_size, + 1, + required_cache_size, + dtype=torch.float32) + input_names = [ + "chunk_xs", + "chunk_lens", + "offset", + "att_cache", + "cnn_cache", + "cache_mask", + ] + output_names = [ + "log_probs", + "log_probs_idx", + "chunk_out", + "chunk_out_lens", + "r_offset", + "r_att_cache", + "r_cnn_cache", + "r_cache_mask", + ] + if args.return_ctc_logprobs: + output_names = [ + "ctc_log_probs", + "chunk_out", + "chunk_out_lens", + "r_offset", + "r_att_cache", + "r_cnn_cache", + "r_cache_mask", + ] + input_tensors = ( + chunk_xs, + chunk_lens, + offset, + att_cache, + cnn_cache, + cache_mask, + ) + if transformer: + assert (args.return_ctc_logprobs is + False), "return_ctc_logprobs is not supported in transformer" + output_names.pop(6) + + all_names = input_names + output_names + dynamic_axes = {} + for name in all_names: + # only the first dimension is dynamic + # all other dimension is fixed + dynamic_axes[name] = {0: "B"} + + torch.onnx.export( + encoder, + input_tensors, + encoder_onnx_path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + verbose=False, + ) + + with torch.no_grad(): + torch_outs = encoder(chunk_xs, chunk_lens, offset, att_cache, + cnn_cache, cache_mask) + if transformer: + torch_outs = list(torch_outs).pop(6) + ort_session = 
onnxruntime.InferenceSession( + encoder_onnx_path, providers=["CUDAExecutionProvider"]) + ort_inputs = {} + + input_tensors = to_numpy(input_tensors) + for idx, name in enumerate(input_names): + ort_inputs[name] = input_tensors[idx] + if transformer: + del ort_inputs["cnn_cache"] + ort_outs = ort_session.run(None, ort_inputs) + test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05) + logger.info("export to onnx streaming encoder succeed!") + onnx_config = { + "subsampling_rate": subsampling, + "context": context, + "decoding_chunk_size": decoding_chunk_size, + "num_decoding_left_chunks": num_decoding_left_chunks, + "beam_size": args.beam_size, + "fp16": args.fp16, + "feat_size": feature_size, + "decoding_window": decoding_window, + "cnn_module_kernel_cache": cnn_module_kernel, + "return_ctc_logprobs": args.return_ctc_logprobs, + } + return onnx_config + + +def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path, + decoder_fastertransformer): + bz, seq_len = 32, 100 + beam_size = args.beam_size + decoder = Decoder( + model.decoder, + model.ctc_weight, + model.reverse_weight, + beam_size, + decoder_fastertransformer, + ) + decoder.eval() + + hyps_pad_sos_eos = torch.randint(low=3, + high=1000, + size=(bz, beam_size, seq_len)) + hyps_lens_sos = torch.randint(low=3, + high=seq_len, + size=(bz, beam_size), + dtype=torch.int32) + r_hyps_pad_sos_eos = torch.randint(low=3, + high=1000, + size=(bz, beam_size, seq_len)) + + output_size = configs["encoder_conf"]["output_size"] + encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32) + encoder_out_lens = torch.randint(low=3, + high=seq_len, + size=(bz, ), + dtype=torch.int32) + ctc_score = torch.randn(bz, beam_size, dtype=torch.float32) + + input_names = [ + "encoder_out", + "encoder_out_lens", + "hyps_pad_sos_eos", + "hyps_lens_sos", + "r_hyps_pad_sos_eos", + "ctc_score", + ] + output_names = ["best_index"] + if decoder_fastertransformer: + output_names.insert(0, "decoder_out") + + torch.onnx.export( + decoder, + ( + encoder_out, + encoder_out_lens, + hyps_pad_sos_eos, + hyps_lens_sos, + r_hyps_pad_sos_eos, + ctc_score, + ), + decoder_onnx_path, + export_params=True, + opset_version=13, + do_constant_folding=True, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "encoder_out": { + 0: "B", + 1: "T" + }, + "encoder_out_lens": { + 0: "B" + }, + "hyps_pad_sos_eos": { + 0: "B", + 2: "T2" + }, + "hyps_lens_sos": { + 0: "B" + }, + "r_hyps_pad_sos_eos": { + 0: "B", + 2: "T2" + }, + "ctc_score": { + 0: "B" + }, + "best_index": { + 0: "B" + }, + }, + verbose=False, + ) + with torch.no_grad(): + o0 = decoder( + encoder_out, + encoder_out_lens, + hyps_pad_sos_eos, + hyps_lens_sos, + r_hyps_pad_sos_eos, + ctc_score, + ) + providers = ["CUDAExecutionProvider"] + ort_session = onnxruntime.InferenceSession(decoder_onnx_path, + providers=providers) + + input_tensors = [ + encoder_out, + encoder_out_lens, + hyps_pad_sos_eos, + hyps_lens_sos, + r_hyps_pad_sos_eos, + ctc_score, + ] + ort_inputs = {} + input_tensors = to_numpy(input_tensors) + for idx, name in enumerate(input_names): + ort_inputs[name] = input_tensors[idx] + + # if model.reverse weight == 0, + # the r_hyps_pad will be removed + # from the onnx decoder since it doen't play any role + if model.reverse_weight == 0: + del ort_inputs["r_hyps_pad_sos_eos"] + ort_outs = ort_session.run(None, ort_inputs) + + # check decoder output + if decoder_fastertransformer: + test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05) + else: + 
test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05) + logger.info("export to onnx decoder succeed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="export x86_gpu model") + parser.add_argument("--config", required=True, help="config file") + parser.add_argument("--checkpoint", required=True, help="checkpoint model") + parser.add_argument( + "--cmvn_file", + required=False, + default="", + type=str, + help="global_cmvn file, default path is in config file", + ) + parser.add_argument( + "--reverse_weight", + default=-1.0, + type=float, + required=False, + help="reverse weight for bitransformer," + + "default value is in config file", + ) + parser.add_argument( + "--ctc_weight", + default=-1.0, + type=float, + required=False, + help="ctc weight, default value is in config file", + ) + parser.add_argument( + "--beam_size", + default=10, + type=int, + required=False, + help="beam size would be ctc output size", + ) + parser.add_argument( + "--output_onnx_dir", + default="onnx_model", + help="output onnx encoder and decoder directory", + ) + parser.add_argument( + "--fp16", + action="store_true", + help="whether to export fp16 model, default false", + ) + # arguments for streaming encoder + parser.add_argument( + "--streaming", + action="store_true", + help="whether to export streaming encoder, default false", + ) + parser.add_argument( + "--decoding_chunk_size", + default=16, + type=int, + required=False, + help="the decoding chunk size, <=0 is not supported", + ) + parser.add_argument( + "--num_decoding_left_chunks", + default=5, + type=int, + required=False, + help="number of left chunks, <= 0 is not supported", + ) + parser.add_argument( + "--decoder_fastertransformer", + action="store_true", + help="return decoder_out and best_index for ft", + ) + parser.add_argument( + "--return_ctc_logprobs", + action="store_true", + help="return full ctc_log_probs for TLG streaming encoder", + ) + args = parser.parse_args() + + torch.manual_seed(0) + torch.set_printoptions(precision=10) + + with open(args.config, "r") as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if args.cmvn_file and os.path.exists(args.cmvn_file): + if 'cmvn' not in configs: + configs['cmvn'] = "global_cmvn" + configs['cmvn_conf'] = {} + else: + assert configs['cmvn'] == "global_cmvn" + assert configs['cmvn_conf'] is not None + configs['cmvn_conf']["cmvn_file"] = args.cmvn_file + if (args.reverse_weight != -1.0 + and "reverse_weight" in configs["model_conf"]): + configs["model_conf"]["reverse_weight"] = args.reverse_weight + print("Update reverse weight to", args.reverse_weight) + if args.ctc_weight != -1: + print("Update ctc weight to ", args.ctc_weight) + configs["model_conf"]["ctc_weight"] = args.ctc_weight + configs["encoder_conf"]["use_dynamic_chunk"] = False + + model, configs = init_model(args, configs) + model.eval() + + if not os.path.exists(args.output_onnx_dir): + os.mkdir(args.output_onnx_dir) + encoder_onnx_path = os.path.join(args.output_onnx_dir, "encoder.onnx") + export_enc_func = None + if args.streaming: + assert args.decoding_chunk_size > 0 + assert args.num_decoding_left_chunks > 0 + export_enc_func = export_online_encoder + else: + export_enc_func = export_offline_encoder + + onnx_config = export_enc_func(model, configs, args, logger, + encoder_onnx_path) + + decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx") + export_rescoring_decoder( + model, + configs, + args, + logger, + decoder_onnx_path, + args.decoder_fastertransformer, + ) + + if 
args.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import ( + convert_float_to_float16, ) + except ImportError: + print("Please install onnxmltools!") + sys.exit(1) + encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path) + encoder_onnx_model = convert_float_to_float16(encoder_onnx_model) + encoder_onnx_path = os.path.join(args.output_onnx_dir, + "encoder_fp16.onnx") + onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path) + decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path) + decoder_onnx_model = convert_float_to_float16(decoder_onnx_model) + decoder_onnx_path = os.path.join(args.output_onnx_dir, + "decoder_fp16.onnx") + onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path) + # dump configurations + + config_dir = os.path.join(args.output_onnx_dir, "config.yaml") + with open(config_dir, "w") as out: + yaml.dump(onnx_config, out) diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py new file mode 100644 index 0000000000000000000000000000000000000000..1782b8ab125164fe9fd554069577d13d44e255e2 --- /dev/null +++ b/wenet/bin/recognize.py @@ -0,0 +1,336 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
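+"""Batch offline recognition with a trained WeNet model.
+
+The script decodes the test set with one or more decoding modes (see
+``--modes``) and writes one ``text`` file per mode under
+``<result_dir>/<mode>/``.
+
+Illustrative invocation (paths are placeholders, not part of this repo):
+
+    python wenet/bin/recognize.py \
+        --config exp/train.yaml \
+        --test_data data/test/data.list \
+        --checkpoint exp/final.pt \
+        --result_dir exp/decode \
+        --modes ctc_greedy_search attention_rescoring
+"""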
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os + +import torch +import yaml +from torch.utils.data import DataLoader + +from wenet.dataset.dataset import Dataset +from wenet.utils.config import override_config +from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer +from wenet.utils.context_graph import ContextGraph +from wenet.utils.ctc_utils import get_blank_id +from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + # choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--device', + type=str, + default="cpu", + choices=["cpu", "npu", "cuda"], + help='accelerator to use') + parser.add_argument('--dtype', + type=str, + default='fp32', + choices=['fp16', 'fp32', 'bf16'], + help='model\'s dtype') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--beam_size', + type=int, + default=10, + help='beam size for search') + parser.add_argument('--length_penalty', + type=float, + default=0.0, + help='length penalty') + parser.add_argument('--blank_penalty', + type=float, + default=0.0, + help='blank penalty') + parser.add_argument('--result_dir', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=16, + help='asr result file') + parser.add_argument('--modes', + nargs='+', + help="""decoding mode, support the following: + attention + ctc_greedy_search + ctc_prefix_beam_search + attention_rescoring + rnnt_greedy_search + rnnt_beam_search + rnnt_beam_attn_rescoring + ctc_beam_td_attn_rescoring + hlg_onebest + hlg_rescore + paraformer_greedy_search + paraformer_beam_search""") + parser.add_argument('--search_ctc_weight', + type=float, + default=1.0, + help='ctc weight for nbest generation') + parser.add_argument('--search_transducer_weight', + type=float, + default=0.0, + help='transducer weight for nbest generation') + parser.add_argument('--ctc_weight', + type=float, + default=0.0, + help='ctc weight for rescoring weight in \ + attention rescoring decode mode \ + ctc weight for rescoring weight in \ + transducer attention rescore decode mode') + + parser.add_argument('--transducer_weight', + type=float, + default=0.0, + help='transducer weight for rescoring weight in ' + 'transducer attention rescore mode') + parser.add_argument('--attn_weight', + type=float, + default=0.0, + help='attention weight for rescoring weight in ' + 'transducer attention rescore mode') + parser.add_argument('--decoding_chunk_size', + type=int, + default=-1, + help='''decoding chunk size, + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here''') + parser.add_argument('--num_decoding_left_chunks', + type=int, + default=-1, + help='number of left chunks for decoding') + parser.add_argument('--simulate_streaming', + action='store_true', + help='simulate streaming inference') + parser.add_argument('--reverse_weight', + type=float, + default=0.0, + help='''right to left weight for attention rescoring + decode mode''') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + + parser.add_argument('--word', + default='', + type=str, + help='word file, only used for hlg decode') + parser.add_argument('--hlg', + default='', + type=str, + help='hlg file, only used for hlg decode') + parser.add_argument('--lm_scale', + type=float, + default=0.0, + help='lm scale for hlg attention rescore decode') + parser.add_argument('--decoder_scale', + type=float, + default=0.0, + help='lm scale for hlg attention rescore decode') + parser.add_argument('--r_decoder_scale', + type=float, + default=0.0, + help='lm scale for hlg attention rescore decode') + + parser.add_argument( + '--context_bias_mode', + type=str, + default='', + help='''Context bias mode, selectable from the following + option: decoding-graph, deep-biasing''') + parser.add_argument('--context_list_path', + type=str, + default='', + help='Context list path') + parser.add_argument('--context_graph_score', + type=float, + default=0.0, + help='''The higher the score, the greater the degree of + bias using decoding-graph for biasing''') + + parser.add_argument('--use_lora', + type=bool, + default=False, + help='''Whether to use lora for biasing''') + parser.add_argument("--lora_ckpt_path", + default=None, + type=str, + help="lora checkpoint path.") + + parser.add_argument('--task', + type=str, + default='asr', + help='Context list path') + parser.add_argument('--lang', + type=str, + default='zh', + help='Context list path') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + if args.gpu != -1: + # remain the original usage of gpu + args.device = "cuda" + if "cuda" in args.device: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + test_conf = copy.deepcopy(configs['dataset_conf']) + + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['filter_conf']['token_max_length'] = 102400 + test_conf['filter_conf']['token_min_length'] = 0 + test_conf['filter_conf']['max_output_input_ratio'] = 102400 + test_conf['filter_conf']['min_output_input_ratio'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['spec_sub'] = False + test_conf['spec_trim'] = False + test_conf['shuffle'] = False + test_conf['sort'] = False + test_conf['cycle'] = 1 + test_conf['list_shuffle'] = False + if 'fbank_conf' in test_conf: + test_conf['fbank_conf']['dither'] = 0.0 + elif 'mfcc_conf' in test_conf: + test_conf['mfcc_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_type'] = "static" + test_conf['batch_conf']['batch_size'] = args.batch_size + + tokenizer = init_tokenizer(configs) + test_dataset = Dataset(args.data_type, + args.test_data, + tokenizer, + test_conf, + partition=False) + + test_data_loader = 
DataLoader(test_dataset, + batch_size=None, + num_workers=args.num_workers) + + # Init asr model from configs + args.jit = False + model, configs = init_model(args, configs) + + device = torch.device(args.device) + model = model.to(device) + model.eval() + dtype = torch.float32 + if args.dtype == 'fp16': + dtype = torch.float16 + elif args.dtype == 'bf16': + dtype = torch.bfloat16 + logging.info("compute dtype is {}".format(dtype)) + + context_graph = None + if 'decoding-graph' in args.context_bias_mode: + context_graph = ContextGraph(args.context_list_path, + tokenizer.symbol_table, + configs['tokenizer_conf']['bpe_path'], + args.context_graph_score) + + _, blank_id = get_blank_id(configs, tokenizer.symbol_table) + logging.info("blank_id is {}".format(blank_id)) + + # TODO(Dinghao Zhou): Support RNN-T related decoding + # TODO(Lv Xiang): Support k2 related decoding + # TODO(Kaixun Huang): Support context graph + files = {} + for mode in args.modes: + dir_name = os.path.join(args.result_dir, mode) + os.makedirs(dir_name, exist_ok=True) + file_name = os.path.join(dir_name, 'text') + files[mode] = open(file_name, 'w', encoding='utf-8') + max_format_len = max([len(mode) for mode in args.modes]) + + with torch.cuda.amp.autocast(enabled=True, + dtype=dtype, + cache_enabled=False): + with torch.no_grad(): + utt_num=0 + # logging.info(f'utt_num: {utt_num}') + for batch_idx, batch in enumerate(test_data_loader): + keys = batch["keys"] + feats = batch["feats"].to(device) + target = batch["target"].to(device) + feats_lengths = batch["feats_lengths"].to(device) + target_lengths = batch["target_lengths"].to(device) + batch_size = feats.size(0) + # task_list = ["transcribe" for i in range(batch_size)] + task_list = [args.task for i in range(batch_size)] + lang_list = [args.lang for i in range(batch_size)] + infos = {"tasks": task_list, "langs":lang_list} + results = model.decode( + args.modes, + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + ctc_weight=args.ctc_weight, + simulate_streaming=args.simulate_streaming, + reverse_weight=args.reverse_weight, + context_graph=context_graph, + blank_id=blank_id, + blank_penalty=args.blank_penalty, + length_penalty=args.length_penalty, + infos=infos) + for i, key in enumerate(keys): + utt_num += 1 + for mode, hyps in results.items(): + tokens = hyps[i].tokens + line = '{} {}'.format(key, + tokenizer.detokenize(tokens)[0]) + logging.info('{} {}'.format(mode.ljust(max_format_len), + line)) + files[mode].write(line + '\n') + # if utt_num % 500 == 0: + # files[mode].flush() + for mode, f in files.items(): + f.flush() # 强制将缓冲区内容刷新到文件 + f.close() + + +if __name__ == '__main__': + main() diff --git a/wenet/bin/recognize4llmasr.py b/wenet/bin/recognize4llmasr.py new file mode 100644 index 0000000000000000000000000000000000000000..725c41bb19a7b685dea2a0def9b02bf01c16cadf --- /dev/null +++ b/wenet/bin/recognize4llmasr.py @@ -0,0 +1,340 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os + +import torch +import yaml +from gxl_ai_utils.utils.utils_model import set_random_seed +from torch.utils.data import DataLoader + +from wenet.dataset.dataset import Dataset +from wenet.llm_asr.llmasr_model import LLMASR_Model +from wenet.utils.config import override_config +from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer +from wenet.utils.context_graph import ContextGraph +from wenet.utils.ctc_utils import get_blank_id +from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + # choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--device', + type=str, + default="cpu", + choices=["cpu", "npu", "cuda"], + help='accelerator to use') + parser.add_argument('--dtype', + type=str, + default='fp32', + choices=['fp16', 'fp32', 'bf16'], + help='model\'s dtype') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--beam_size', + type=int, + default=10, + help='beam size for search') + parser.add_argument('--length_penalty', + type=float, + default=0.0, + help='length penalty') + parser.add_argument('--blank_penalty', + type=float, + default=0.0, + help='blank penalty') + parser.add_argument('--result_dir', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=16, + help='asr result file') + parser.add_argument('--modes', + nargs='+', + help="""decoding mode, support the following: + attention + ctc_greedy_search + ctc_prefix_beam_search + attention_rescoring + rnnt_greedy_search + rnnt_beam_search + rnnt_beam_attn_rescoring + ctc_beam_td_attn_rescoring + hlg_onebest + hlg_rescore + paraformer_greedy_search + paraformer_beam_search""") + parser.add_argument('--search_ctc_weight', + type=float, + default=1.0, + help='ctc weight for nbest generation') + parser.add_argument('--search_transducer_weight', + type=float, + default=0.0, + help='transducer weight for nbest generation') + parser.add_argument('--ctc_weight', + type=float, + default=0.0, + help='ctc weight for rescoring weight in \ + attention rescoring decode mode \ + ctc weight for rescoring weight in \ + transducer attention rescore decode mode') + + parser.add_argument('--transducer_weight', + type=float, + default=0.0, + help='transducer weight for rescoring weight in ' + 'transducer attention rescore mode') + parser.add_argument('--attn_weight', + type=float, + default=0.0, + help='attention weight for rescoring weight in ' + 'transducer attention rescore mode') + parser.add_argument('--decoding_chunk_size', + type=int, + default=-1, + help='''decoding chunk size, + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here''') + parser.add_argument('--num_decoding_left_chunks', + type=int, + default=-1, + help='number of left chunks for decoding') + parser.add_argument('--simulate_streaming', + action='store_true', + help='simulate streaming inference') + parser.add_argument('--reverse_weight', + type=float, + default=0.0, + help='''right to left weight for attention rescoring + decode mode''') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + + parser.add_argument('--word', + default='', + type=str, + help='word file, only used for hlg decode') + parser.add_argument('--hlg', + default='', + type=str, + help='hlg file, only used for hlg decode') + parser.add_argument('--lm_scale', + type=float, + default=0.0, + help='lm scale for hlg attention rescore decode') + parser.add_argument('--decoder_scale', + type=float, + default=0.0, + help='lm scale for hlg attention rescore decode') + parser.add_argument('--r_decoder_scale', + type=float, + default=0.0, + help='lm scale for hlg attention rescore decode') + + parser.add_argument( + '--context_bias_mode', + type=str, + default='', + help='''Context bias mode, selectable from the following + option: decoding-graph, deep-biasing''') + parser.add_argument('--context_list_path', + type=str, + default='', + help='Context list path') + parser.add_argument('--context_graph_score', + type=float, + default=0.0, + help='''The higher the score, the greater the degree of + bias using decoding-graph for biasing''') + + parser.add_argument('--use_lora', + type=bool, + default=False, + help='''Whether to use lora for biasing''') + parser.add_argument("--lora_ckpt_path", + default=None, + type=str, + help="lora checkpoint path.") + + parser.add_argument('--task', + type=str, + default='asr', + help='Context list path') + parser.add_argument('--lang', + type=str, + default='zh', + help='Context list path') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + set_random_seed(777) + + if args.gpu != -1: + # remain the original usage of gpu + args.device = "cuda" + if "cuda" in args.device: + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + configs['dataset_conf']['filter_conf']['filter_no_extra_info'] = False + test_conf = copy.deepcopy(configs['dataset_conf']) + + test_conf['filter_conf']['max_length'] = 3000 # whisper最长处理30s 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['filter_conf']['token_max_length'] = 102400 + test_conf['filter_conf']['token_min_length'] = 0 + test_conf['filter_conf']['max_output_input_ratio'] = 102400 + test_conf['filter_conf']['min_output_input_ratio'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['spec_sub'] = False + test_conf['spec_trim'] = False + test_conf['shuffle'] = True + test_conf['sort'] = False + test_conf['cycle'] = 1 + test_conf['list_shuffle'] = True + if 'fbank_conf' in test_conf: + test_conf['fbank_conf']['dither'] = 0.0 + elif 'mfcc_conf' in test_conf: + test_conf['mfcc_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_type'] = "static" + test_conf['batch_conf']['batch_size'] = 1 + test_conf['split_num'] = 1 + + + tokenizer = init_tokenizer(configs) + 
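+    # NOTE: unlike recognize.py, this script forces batch_size to 1 and
+    # shuffles the test list (see the test_conf overrides above); the decode
+    # loop below only writes keys[0]/res_text[0], so it assumes one
+    # utterance per batch.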
test_dataset = Dataset(args.data_type, + args.test_data, + tokenizer, + test_conf, + partition=False) + + test_data_loader = DataLoader(test_dataset, + batch_size=None, + num_workers=args.num_workers) + + # Init asr model from configs + args.jit = False + model, configs = init_model(args, configs) + + device = torch.device(args.device) + model:LLMASR_Model = model.to(device) + model.eval() + dtype = torch.float32 + if args.dtype == 'fp16': + dtype = torch.float16 + elif args.dtype == 'bf16': + dtype = torch.bfloat16 + logging.info("compute dtype is {}".format(dtype)) + + context_graph = None + if 'decoding-graph' in args.context_bias_mode: + context_graph = ContextGraph(args.context_list_path, + tokenizer.symbol_table, + configs['tokenizer_conf']['bpe_path'], + args.context_graph_score) + + _, blank_id = get_blank_id(configs, tokenizer.symbol_table) + logging.info("blank_id is {}".format(blank_id)) + + # TODO(Dinghao Zhou): Support RNN-T related decoding + # TODO(Lv Xiang): Support k2 related decoding + # TODO(Kaixun Huang): Support context graph + files = {} + modes = ['llmasr_decode'] + for mode in modes: + dir_name = os.path.join(args.result_dir, mode) + os.makedirs(dir_name, exist_ok=True) + file_name = os.path.join(dir_name, 'text') + files[mode] = open(file_name, 'w', encoding='utf-8') + max_format_len = max([len(mode) for mode in args.modes]) + + # Get prompt config + from gxl_ai_utils.utils import utils_file + global_prompt_dict = utils_file.load_dict_from_yaml('conf/prompt_stage4.yaml') + + with torch.cuda.amp.autocast(enabled=True, + dtype=dtype, + cache_enabled=False): + with torch.no_grad(): + # logging.info(f'utt_num: {utt_num}') + for batch_idx, batch in enumerate(test_data_loader): + keys = batch["keys"] + feats = batch["feats"].to(device) + target = batch["target"].to(device) + feats_lengths = batch["feats_lengths"].to(device) + target_lengths = batch["target_lengths"].to(device) + batch_size = feats.size(0) + + import random + if '><' in args.task: + args.task = args.task.replace('><', '> <') + if args.task == "" or args.task == "": + is_truncation = False + else: + is_truncation = True + random_index = random.randint(0, len(global_prompt_dict[args.task])-1) + prompt = global_prompt_dict[args.task][random_index] + # print(args.task, prompt) + + res_text = model.generate(wavs=feats, wavs_len=feats_lengths, prompt=prompt) + for mode in modes: + line = "{}\t{}".format(keys[0], res_text[0]) + files[mode].write(line+'\n') + utils_file.logging_print( '{} {} {}'.format(batch_idx, keys[0], res_text[0])) + if batch_idx % 100 == 0: + for mode, f in files.items(): + f.flush() # 强制将缓冲区内容刷新到文件 + # if batch_idx >= 1000 and is_truncation: + # utils_file.logging_info('采用截断至3000的策略') + # break + for mode, f in files.items(): + f.flush() # 强制将缓冲区内容刷新到文件 + f.close() + + +if __name__ == '__main__': + main() diff --git a/wenet/bin/recognize_onnx_gpu.py b/wenet/bin/recognize_onnx_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..373c3ddbec7fcac4d8c7cbf7c2549e89bdd88617 --- /dev/null +++ b/wenet/bin/recognize_onnx_gpu.py @@ -0,0 +1,297 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is for testing exported onnx encoder and decoder from +export_onnx_gpu.py. The exported onnx models only support batch offline ASR inference. +It requires a python wrapped c++ ctc decoder. +Please install it by following: +https://github.com/Slyne/ctc_decoder.git +""" +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader + +from wenet.dataset.dataset import Dataset +from wenet.utils.common import IGNORE_ID +from wenet.utils.config import override_config +from wenet.utils.init_tokenizer import init_tokenizer + +import onnxruntime as rt +import multiprocessing +import numpy as np + +try: + from swig_decoders import map_batch, \ + ctc_beam_search_decoder_batch, \ + TrieVector, PathTrie +except ImportError: + print('Please install ctc decoders first by refering to\n' + + 'https://github.com/Slyne/ctc_decoder.git') + sys.exit(1) + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--encoder_onnx', + required=True, + help='encoder onnx file') + parser.add_argument('--decoder_onnx', + required=True, + help='decoder onnx file') + parser.add_argument('--result_file', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=32, + help='asr result file') + parser.add_argument('--mode', + choices=[ + 'ctc_greedy_search', 'ctc_prefix_beam_search', + 'attention_rescoring' + ], + default='attention_rescoring', + help='decoding mode') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + parser.add_argument('--fp16', + action='store_true', + help='whether to export fp16 model, default false') + args = parser.parse_args() + return args + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + 
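+    # Setting CUDA_VISIBLE_DEVICES to "-1" hides all GPUs, so --gpu -1
+    # effectively forces CPU execution; use_cuda below then selects
+    # CPUExecutionProvider for onnxruntime.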
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + reverse_weight = configs["model_conf"].get("reverse_weight", 0.0) + special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None) + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['filter_conf']['token_max_length'] = 102400 + test_conf['filter_conf']['token_min_length'] = 0 + test_conf['filter_conf']['max_output_input_ratio'] = 102400 + test_conf['filter_conf']['min_output_input_ratio'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['spec_sub'] = False + test_conf['spec_trim'] = False + test_conf['shuffle'] = False + test_conf['sort'] = False + test_conf['fbank_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_type'] = "static" + test_conf['batch_conf']['batch_size'] = args.batch_size + + tokenizer = init_tokenizer(configs) + test_dataset = Dataset(args.data_type, + args.test_data, + tokenizer, + test_conf, + partition=False) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + # Init asr model from configs + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + if use_cuda: + EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + EP_list = ['CPUExecutionProvider'] + + encoder_ort_session = rt.InferenceSession(args.encoder_onnx, + providers=EP_list) + decoder_ort_session = None + if args.mode == "attention_rescoring": + decoder_ort_session = rt.InferenceSession(args.decoder_onnx, + providers=EP_list) + + # Load dict + vocabulary = [] + char_dict = {} + with open(args.dict, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + vocabulary.append(arr[0]) + + vocab_size = len(char_dict) + sos = (vocab_size - 1 if special_tokens is None else + special_tokens.get("", vocab_size - 1)) + eos = (vocab_size - 1 if special_tokens is None else + special_tokens.get("", vocab_size - 1)) + + with torch.no_grad(), open(args.result_file, 'w') as fout: + for _, batch in enumerate(test_data_loader): + keys = batch['keys'] + feats = batch['feats'] + feats_lengths = batch['feats_lengths'] + feats, feats_lengths = feats.numpy(), feats_lengths.numpy() + if args.fp16: + feats = feats.astype(np.float16) + ort_inputs = { + encoder_ort_session.get_inputs()[0].name: feats, + encoder_ort_session.get_inputs()[1].name: feats_lengths + } + ort_outs = encoder_ort_session.run(None, ort_inputs) + encoder_out, encoder_out_lens, ctc_log_probs, \ + beam_log_probs, beam_log_probs_idx = ort_outs + beam_size = beam_log_probs.shape[-1] + batch_size = beam_log_probs.shape[0] + num_processes = min(multiprocessing.cpu_count(), batch_size) + if args.mode == 'ctc_greedy_search': + if beam_size != 1: + log_probs_idx = beam_log_probs_idx[:, :, 0] + batch_sents = [] + for idx, seq in enumerate(log_probs_idx): + batch_sents.append(seq[0:encoder_out_lens[idx]].tolist()) + hyps = map_batch(batch_sents, vocabulary, num_processes, True, + 0) + elif args.mode in ('ctc_prefix_beam_search', + "attention_rescoring"): + batch_log_probs_seq_list = beam_log_probs.tolist() + batch_log_probs_idx_list = beam_log_probs_idx.tolist() + batch_len_list = encoder_out_lens.tolist() + batch_log_probs_seq = [] + batch_log_probs_ids = [] + 
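+                # Build per-utterance inputs for the batched CTC prefix beam
+                # search below: log-prob matrices truncated to their true
+                # lengths, plus one PathTrie root per utterance.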
batch_start = [] # only effective in streaming deployment + batch_root = TrieVector() + root_dict = {} + for i in range(len(batch_len_list)): + num_sent = batch_len_list[i] + batch_log_probs_seq.append( + batch_log_probs_seq_list[i][0:num_sent]) + batch_log_probs_ids.append( + batch_log_probs_idx_list[i][0:num_sent]) + root_dict[i] = PathTrie() + batch_root.append(root_dict[i]) + batch_start.append(True) + score_hyps = ctc_beam_search_decoder_batch( + batch_log_probs_seq, batch_log_probs_ids, batch_root, + batch_start, beam_size, num_processes, 0, -2, 0.99999) + if args.mode == 'ctc_prefix_beam_search': + hyps = [] + for cand_hyps in score_hyps: + hyps.append(cand_hyps[0][1]) + hyps = map_batch(hyps, vocabulary, num_processes, False, 0) + if args.mode == 'attention_rescoring': + ctc_score, all_hyps = [], [] + max_len = 0 + for hyps in score_hyps: + cur_len = len(hyps) + if len(hyps) < beam_size: + hyps += (beam_size - cur_len) * [(-float("INF"), + (0, ))] + cur_ctc_score = [] + for hyp in hyps: + cur_ctc_score.append(hyp[0]) + all_hyps.append(list(hyp[1])) + if len(hyp[1]) > max_len: + max_len = len(hyp[1]) + ctc_score.append(cur_ctc_score) + if args.fp16: + ctc_score = np.array(ctc_score, dtype=np.float16) + else: + ctc_score = np.array(ctc_score, dtype=np.float32) + hyps_pad_sos_eos = np.ones( + (batch_size, beam_size, max_len + 2), + dtype=np.int64) * IGNORE_ID + r_hyps_pad_sos_eos = np.ones( + (batch_size, beam_size, max_len + 2), + dtype=np.int64) * IGNORE_ID + hyps_lens_sos = np.ones((batch_size, beam_size), + dtype=np.int32) + k = 0 + for i in range(batch_size): + for j in range(beam_size): + cand = all_hyps[k] + l = len(cand) + 2 + hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos] + r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [ + eos + ] + hyps_lens_sos[i][j] = len(cand) + 1 + k += 1 + decoder_ort_inputs = { + decoder_ort_session.get_inputs()[0].name: encoder_out, + decoder_ort_session.get_inputs()[1].name: encoder_out_lens, + decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos, + decoder_ort_session.get_inputs()[3].name: hyps_lens_sos, + decoder_ort_session.get_inputs()[-1].name: ctc_score + } + if reverse_weight > 0: + r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs( + )[4].name + decoder_ort_inputs[ + r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos + best_index = decoder_ort_session.run(None, + decoder_ort_inputs)[0] + best_sents = [] + k = 0 + for idx in best_index: + cur_best_sent = all_hyps[k:k + beam_size][idx] + best_sents.append(cur_best_sent) + k += beam_size + hyps = map_batch(best_sents, vocabulary, num_processes) + + for i, key in enumerate(keys): + content = hyps[i] + logging.info('{} {}'.format(key, content)) + fout.write('{} {}\n'.format(key, content)) + +if __name__ == '__main__': + main() diff --git a/wenet/bin/train.py b/wenet/bin/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9edf684c9c1750fd3e1ec2611d7427f58391e3d3 --- /dev/null +++ b/wenet/bin/train.py @@ -0,0 +1,232 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import datetime +import logging +import os +import random + +import numpy as np +import yaml +import torch + +import torch.distributed as dist + +from torch.distributed.elastic.multiprocessing.errors import record +from wenet.utils.common import lrs_to_str, TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu + +from wenet.utils.executor import Executor +from wenet.utils.config import override_config +from wenet.utils.init_model import init_model +from wenet.utils.init_tokenizer import init_tokenizer +from wenet.utils.train_utils import ( + add_fsdp_args, add_model_args, add_dataset_args, add_ddp_args, + add_deepspeed_args, add_trace_args, init_distributed, + init_dataset_and_dataloader, check_modify_and_save_config, + init_optimizer_and_scheduler, init_scaler, trace_and_print_model, + wrap_cuda_model, init_summarywriter, save_model, log_per_epoch, + add_lora_args, reinit_lora) +from gxl_ai_utils.utils import utils_file + +try: + import torch_npu + + torch_npu.npu.conv.allow_hf32 = False + # import deepspeed_npu + from torch_npu.npu import amp + from torch_npu.contrib import transfer_to_npu +except ImportError: + utils_file.logging_warning( + "torch_npu is not installed, please install torch_npu first if you want to use torch_npu") +torch.backends.cudnn.allow_tf32 = False +torch.backends.cuda.matmul.allow_tf32 = False + +from msprobe.pytorch import seed_all +import gc + +gc.set_threshold(700, 10, 10000) # python gc阈值设置 + + +# import deepspeed_npu +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--train_engine', + default='torch_ddp', + choices=['torch_ddp', 'torch_fsdp', 'deepspeed'], + help='Engine for paralleled training') + # set default value of device to "cuda", avoiding the modify of original scripts + parser.add_argument('--device', + type=str, + default='cuda', + choices=["cpu", "npu", "cuda"], + help='accelerator for training') + # load deepspeed checkpoint + parser.add_argument('--load_dir', + type=str, + default=None) + parser.add_argument('--ckpt_id', + type=str, + default=None) + parser = add_model_args(parser) + parser = add_dataset_args(parser) + parser = add_ddp_args(parser) + parser = add_lora_args(parser) + parser = add_deepspeed_args(parser) + parser = add_fsdp_args(parser) + parser = add_trace_args(parser) + args = parser.parse_args() + if args.train_engine == "deepspeed": + args.deepspeed = True + assert args.deepspeed_config is not None + return args + + +# NOTE(xcsong): On worker errors, this recod tool will summarize the +# details of the error (e.g. time, rank, host, pid, traceback, etc). 
+@record +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + + # Set random seed + torch.manual_seed(777) + random.seed(777) + np.random.seed(777) + utils_file.logging_info('开始严格seed') + seed_all(777) + utils_file.logging_info('结束严格seed') + logging.info('Random seed set to {}'.format(777)) + + # Read config + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + # init tokenizer + tokenizer = init_tokenizer(configs) + + # Init env for ddp OR deepspeed + _, _, rank = init_distributed(args) + + # Init asr model from configs + model, configs = init_model(args, configs) + + # Get dataset & dataloader + train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ + init_dataset_and_dataloader(args, configs, tokenizer) + + # Do some sanity checks and save config to arsg.model_dir + configs = check_modify_and_save_config(args, configs, + tokenizer.symbol_table) + + if hasattr(args, 'lora_reinit') and args.lora_reinit: + reinit_lora(model, args, configs, tokenizer) + + # Check model is jitable & print model archtectures + trace_and_print_model(args, model) + + # Tensorboard summary + writer = init_summarywriter(args) + + # Dispatch model from cpu to gpu + model, device = wrap_cuda_model(args, model, configs) + + # Get optimizer & scheduler + model, optimizer, scheduler = init_optimizer_and_scheduler( + args, configs, model) + + # Load deepspeed checkpoint + if args.load_dir is not None and \ + args.ckpt_id is not None: + _, client_sd = model.load_checkpoint(args.load_dir, args.ckpt_id) + + # Save checkpoints + # save_model(model, + # info_dict={ + # "save_time": + # datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), + # "tag": + # "init", + # **configs + # }) + + # Get executor + tag = configs["init_infos"].get("tag", "init") + executor = Executor(global_step=configs["init_infos"].get('step', -1), + device=device) + + # Init scaler, used for pytorch amp mixed precision training + scaler = init_scaler(args) + + # Start training loop + start_epoch = configs["init_infos"].get('epoch', 0) + int("epoch_" in tag) + # if save_interval in configs, steps mode else epoch mode + end_epoch = configs.get('max_epoch', 100) + assert start_epoch <= end_epoch + configs.pop("init_infos", None) + final_epoch = None + for epoch in range(start_epoch, end_epoch): + configs['epoch'] = epoch + + lrs = [group['lr'] for group in optimizer.param_groups] + logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format( + epoch, executor.step, lrs_to_str(lrs), rank)) + + dist.barrier( + ) # NOTE(xcsong): Ensure all ranks start Train at the same time. + # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join` + group_join = dist.new_group( # fix by zhaoyi for 多机训练 + backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) + # group_join = None + executor.train(model, optimizer, scheduler, train_data_loader, + cv_data_loader, writer, configs, scaler, group_join) + # dist.destroy_process_group(group_join) + + dist.barrier( + ) # NOTE(xcsong): Ensure all ranks start CV at the same time. 
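+        # Per-epoch cross validation; loss_dict is stored in the checkpoint
+        # metadata and logged to tensorboard via log_per_epoch() below.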
+ loss_dict = executor.cv(model, cv_data_loader, configs) + info_dict = { + 'epoch': epoch, + 'lrs': [group['lr'] for group in optimizer.param_groups], + 'step': executor.step, + "loss_dict": loss_dict, + 'save_time': datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'), + 'tag': "epoch_{}".format(epoch), + 'loss_dict': loss_dict, + **configs + } + # epoch cv: tensorboard && log + log_per_epoch(writer, info_dict=info_dict) + save_model(model, info_dict=info_dict) + + final_epoch = epoch + + if final_epoch is not None and rank == 0: + final_model_path = os.path.join(args.model_dir, 'final.pt') + os.remove(final_model_path) if os.path.exists( + final_model_path) else None + os.symlink('{}.pt'.format(final_epoch), final_model_path) + writer.close() + dist.barrier( + ) # NOTE(yktian): Ensure all ranks end Train before destroy process group. + dist.destroy_process_group() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/wenet/branchformer/__init__.py b/wenet/branchformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wenet/branchformer/cgmlp.py b/wenet/branchformer/cgmlp.py new file mode 100644 index 0000000000000000000000000000000000000000..b56a2505e2512689503e70b77edfc84f08dafc99 --- /dev/null +++ b/wenet/branchformer/cgmlp.py @@ -0,0 +1,194 @@ +# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University) +# 2023 Voicecomm Inc (Kai Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""MLP with convolutional gating (cgMLP) definition. + +References: + https://openreview.net/forum?id=RA-zVvZLYIy + https://arxiv.org/abs/2105.08050 + +""" + +from typing import Tuple +import torch +import torch.nn as nn +from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES + + +class ConvolutionalSpatialGatingUnit(torch.nn.Module): + """Convolutional Spatial Gating Unit (CSGU).""" + + def __init__( + self, + size: int, + kernel_size: int, + dropout_rate: float, + use_linear_after_conv: bool, + gate_activation: str, + causal: bool = True, + ): + super().__init__() + + # split input channels + n_channels = size // 2 + self.norm = nn.LayerNorm(n_channels) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. 
+ # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.conv = torch.nn.Conv1d( + n_channels, + n_channels, + kernel_size, + 1, + padding, + groups=n_channels, + ) + if use_linear_after_conv: + self.linear = torch.nn.Linear(n_channels, n_channels) + else: + self.linear = None + + if gate_activation == "identity": + self.act = torch.nn.Identity() + else: + self.act = WENET_ACTIVATION_CLASSES[gate_activation]() + + self.dropout = torch.nn.Dropout(dropout_rate) + + def espnet_initialization_fn(self): + torch.nn.init.normal_(self.conv.weight, std=1e-6) + torch.nn.init.ones_(self.conv.bias) + if self.linear is not None: + torch.nn.init.normal_(self.linear.weight, std=1e-6) + torch.nn.init.ones_(self.linear.bias) + + def forward( + self, x: torch.Tensor, cache: torch.Tensor = torch.zeros((0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward method + + Args: + x (torch.Tensor): (batch, time, channels) + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + + Returns: + out (torch.Tensor): (batch, time, channels/2) + """ + + x_r, x_g = x.chunk(2, dim=-1) + # exchange the temporal dimension and the feature dimension + x_g = x_g.transpose(1, 2) # (#batch, channels, time) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x_g = nn.functional.pad(x_g, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x_g.size(0) # equal batch + assert cache.size(1) == x_g.size(1) # equal channel + x_g = torch.cat((cache, x_g), dim=2) + assert (x_g.size(2) > self.lorder) + new_cache = x_g[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = torch.zeros((0, 0, 0), + dtype=x_g.dtype, + device=x_g.device) + + x_g = x_g.transpose(1, 2) + x_g = self.norm(x_g) # (N, T, D/2) + x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2) # (N, T, D/2) + if self.linear is not None: + x_g = self.linear(x_g) + + x_g = self.act(x_g) + out = x_r * x_g # (N, T, D/2) + out = self.dropout(out) + return out, new_cache + + +class ConvolutionalGatingMLP(torch.nn.Module): + """Convolutional Gating MLP (cgMLP).""" + + def __init__( + self, + size: int, + linear_units: int, + kernel_size: int, + dropout_rate: float, + use_linear_after_conv: bool, + gate_activation: str, + causal: bool = True, + ): + super().__init__() + + self.channel_proj1 = torch.nn.Sequential( + torch.nn.Linear(size, linear_units), torch.nn.GELU()) + self.csgu = ConvolutionalSpatialGatingUnit( + size=linear_units, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + use_linear_after_conv=use_linear_after_conv, + gate_activation=gate_activation, + causal=causal, + ) + self.channel_proj2 = torch.nn.Linear(linear_units // 2, size) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + cache: torch.Tensor = torch.zeros((0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward method + + Args: + x (torch.Tensor): (batch, time, channels) + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. Not used yet + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. 
+ + Returns: + out (torch.Tensor): (batch, time, channels/2) + """ + + xs_pad = x + + # size -> linear_units + xs_pad = self.channel_proj1(xs_pad) + + # linear_units -> linear_units/2 + xs_pad, new_cnn_cache = self.csgu(xs_pad, cache) + + # linear_units/2 -> size + xs_pad = self.channel_proj2(xs_pad) + + out = xs_pad + + return out, new_cnn_cache diff --git a/wenet/branchformer/encoder.py b/wenet/branchformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2feda978e19ebbc1c0b4c1b6e94913e7c582e4c3 --- /dev/null +++ b/wenet/branchformer/encoder.py @@ -0,0 +1,177 @@ +# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University) +# 2023 Voicecomm Inc (Kai Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition.""" + +import torch + +from typing import List, Optional, Union + +from wenet.branchformer.encoder_layer import BranchformerEncoderLayer +from wenet.branchformer.cgmlp import ConvolutionalGatingMLP +from wenet.transformer.encoder import BaseEncoder +from wenet.utils.class_utils import ( + WENET_ATTENTION_CLASSES, ) + + +class BranchformerEncoder(BaseEncoder): + """Branchformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + use_attn: bool = True, + attention_heads: int = 4, + selfattention_layer_type: str = "rel_selfattn", + pos_enc_layer_type: str = "rel_pos", + use_cgmlp: bool = True, + cgmlp_linear_units: int = 2048, + cgmlp_conv_kernel: int = 31, + use_linear_after_conv: bool = False, + gate_activation: str = "identity", + merge_method: str = "concat", + cgmlp_weight: Union[float, List[float]] = 0.5, + attn_branch_drop_rate: Union[float, List[float]] = 0.0, + num_blocks: int = 12, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + stochastic_depth_rate: Union[float, List[float]] = 0.0, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + causal: bool = False, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + ): + super().__init__(input_size, output_size, attention_heads, + cgmlp_linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, True, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, gradient_checkpointing, + use_sdpa, layer_norm_type, norm_eps) + + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + query_bias, + key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, + ) + + cgmlp_layer = ConvolutionalGatingMLP + cgmlp_layer_args = ( + output_size, + cgmlp_linear_units, + 
cgmlp_conv_kernel, + dropout_rate, + use_linear_after_conv, + gate_activation, + causal, + ) + + if isinstance(stochastic_depth_rate, float): + stochastic_depth_rate = [stochastic_depth_rate] * num_blocks + if len(stochastic_depth_rate) != num_blocks: + raise ValueError( + f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " + f"should be equal to num_blocks ({num_blocks})") + + if isinstance(cgmlp_weight, float): + cgmlp_weight = [cgmlp_weight] * num_blocks + if len(cgmlp_weight) != num_blocks: + raise ValueError( + f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to " + f"num_blocks ({num_blocks})") + + if isinstance(attn_branch_drop_rate, float): + attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks + if len(attn_branch_drop_rate) != num_blocks: + raise ValueError( + f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) " + f"should be equal to num_blocks ({num_blocks})") + + self.encoders = LayerDropModuleList( + p=stochastic_depth_rate, + modules=[ + BranchformerEncoderLayer( + output_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args) if use_attn else None, + cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, + dropout_rate, + merge_method, + cgmlp_weight[lnum], + attn_branch_drop_rate[lnum], + stochastic_depth_rate[lnum], + ) for lnum in range(num_blocks) + ]) + + +# modify from : https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/layer_drop.py # noqa +class LayerDropModuleList(torch.nn.ModuleList): + """ + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + We refresh the choice of which layers to drop every time we iterate + over the LayerDropModuleList instance. During evaluation we always + iterate over all layers. + + Usage:: + + layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) + for layer in layers: # this might iterate over layers 1 and 3 + x = layer(x) + for layer in layers: # this might iterate over all layers + x = layer(x) + for layer in layers: # this might not iterate over any layers + x = layer(x) + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + + Limitations: + 1 can work with ddp when layer's gradient checkpoint disabled + 2 can't work with ddp when layer's gradient checkpoint enables + 3 can work with fsdp + 4 can work with deepspeed + """ + + def __init__(self, p: List[float], modules=None): + super().__init__(modules) + assert len(p) == len(self) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p[i]): + yield m diff --git a/wenet/branchformer/encoder_layer.py b/wenet/branchformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..a48feefbd1d2e9d54394e32cd16a90c089e0ceae --- /dev/null +++ b/wenet/branchformer/encoder_layer.py @@ -0,0 +1,245 @@ +# Copyright (c) 2022 Yifan Peng (Carnegie Mellon University) +# 2023 Voicecomm Inc (Kai Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) +"""BranchformerEncoderLayer definition.""" + +import torch +import torch.nn as nn +from typing import Optional, Tuple + +from wenet.transformer.attention import T_CACHE + + +class BranchformerEncoderLayer(torch.nn.Module): + """Branchformer encoder layer module. + + Args: + size (int): model dimension + attn: standard self-attention or efficient attention, optional + cgmlp: ConvolutionalGatingMLP, optional + dropout_rate (float): dropout probability + merge_method (str): concat, learned_ave, fixed_ave + cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1, + used if merge_method is fixed_ave + attn_branch_drop_rate (float): probability of dropping the attn branch, + used if merge_method is learned_ave + stochastic_depth_rate (float): stochastic depth probability + """ + + def __init__( + self, + size: int, + attn: Optional[torch.nn.Module], + cgmlp: Optional[torch.nn.Module], + dropout_rate: float, + merge_method: str, + cgmlp_weight: float = 0.5, + attn_branch_drop_rate: float = 0.0, + stochastic_depth_rate: float = 0.0, + ): + super().__init__() + assert (attn is not None) or ( + cgmlp is not None), "At least one branch should be valid" + + self.size = size + self.attn = attn + self.cgmlp = cgmlp + self.merge_method = merge_method + self.cgmlp_weight = cgmlp_weight + self.attn_branch_drop_rate = attn_branch_drop_rate + self.stochastic_depth_rate = stochastic_depth_rate + self.use_two_branches = (attn is not None) and (cgmlp is not None) + + if attn is not None: + self.norm_mha = nn.LayerNorm(size) # for the MHA module + if cgmlp is not None: + self.norm_mlp = nn.LayerNorm(size) # for the MLP module + self.norm_final = nn.LayerNorm( + size) # for the final output of the block + + self.dropout = torch.nn.Dropout(dropout_rate) + + # # attention-based pooling for two branches + self.pooling_proj1 = torch.nn.Linear(size, 1) + self.pooling_proj2 = torch.nn.Linear(size, 1) + + # # linear projections for calculating merging weights + self.weight_proj1 = torch.nn.Linear(size, 1) + self.weight_proj2 = torch.nn.Linear(size, 1) + + if self.use_two_branches: + if self.merge_method == "concat": + self.merge_proj = torch.nn.Linear(size + size, size) + + elif self.merge_method == "learned_ave": + # linear projection after weighted average + self.merge_proj = torch.nn.Linear(size, size) + + elif self.merge_method == "fixed_ave": + assert (0.0 <= cgmlp_weight <= + 1.0), "cgmlp weight should be between 0.0 and 1.0" + + # remove the other branch if only one branch is used + if cgmlp_weight == 0.0: + self.use_two_branches = False + self.cgmlp = None + self.norm_mlp = None + elif cgmlp_weight == 1.0: + self.use_two_branches = False + self.attn = None + self.norm_mha = None + + # linear projection after weighted average + self.merge_proj = torch.nn.Linear(size, size) + else: + raise ValueError(f"unknown merge method: {merge_method}") + else: + self.merge_proj = torch.nn.Identity() + + def _forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + stoch_layer_coeff: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: + # Two branches + x1 = x + x2 = x + + # Branch 1: multi-headed 
attention module + if self.attn is not None: + x1 = self.norm_mha(x1) + x_att, new_att_cache = self.attn(x1, x1, x1, mask, pos_emb, + att_cache) + x1 = self.dropout(x_att) + + # Branch 2: convolutional gating mlp + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.cgmlp is not None: + x2 = self.norm_mlp(x2) + x2, new_cnn_cache = self.cgmlp(x2, mask_pad, cnn_cache) + x2 = self.dropout(x2) + + # Merge two branches + if self.use_two_branches: + if self.merge_method == "concat": + x = x + stoch_layer_coeff * self.dropout( + self.merge_proj(torch.cat([x1, x2], dim=-1))) + elif self.merge_method == "learned_ave": + if (self.training and self.attn_branch_drop_rate > 0 + and torch.rand(1).item() < self.attn_branch_drop_rate): + # Drop the attn branch + w1, w2 = torch.tensor(0.0), torch.tensor(1.0) + else: + # branch1 + score1 = (self.pooling_proj1(x1).transpose(1, 2) / + self.size**0.5) + score1 = score1.masked_fill(mask_pad.eq(0), -float('inf')) + score1 = torch.softmax(score1, dim=-1).masked_fill( + mask_pad.eq(0), 0.0) + + pooled1 = torch.matmul(score1, + x1).squeeze(1) # (batch, size) + weight1 = self.weight_proj1(pooled1) # (batch, 1) + + # branch2 + score2 = (self.pooling_proj2(x2).transpose(1, 2) / + self.size**0.5) + score2 = score2.masked_fill(mask_pad.eq(0), -float('inf')) + score2 = torch.softmax(score2, dim=-1).masked_fill( + mask_pad.eq(0), 0.0) + + pooled2 = torch.matmul(score2, + x2).squeeze(1) # (batch, size) + weight2 = self.weight_proj2(pooled2) # (batch, 1) + + # normalize weights of two branches + merge_weights = torch.softmax(torch.cat([weight1, weight2], + dim=-1), + dim=-1) # (batch, 2) + merge_weights = merge_weights.unsqueeze(-1).unsqueeze( + -1) # (batch, 2, 1, 1) + w1, w2 = merge_weights[:, + 0], merge_weights[:, + 1] # (batch, 1, 1) + + x = x + stoch_layer_coeff * self.dropout( + self.merge_proj(w1 * x1 + w2 * x2)) + elif self.merge_method == "fixed_ave": + x = x + stoch_layer_coeff * self.dropout( + self.merge_proj((1.0 - self.cgmlp_weight) * x1 + + self.cgmlp_weight * x2)) + else: + raise RuntimeError( + f"unknown merge method: {self.merge_method}") + else: + if self.attn is None: + x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2)) + elif self.cgmlp is None: + x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1)) + else: + # This should not happen + raise RuntimeError( + "Both branches are not None, which is unexpected.") + + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: T_CACHE = (torch.zeros( + (0, 0, 0, 0)), torch.zeros(0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE, torch.Tensor]: + """Compute encoded features. + + Args: + x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time, time). + pos_emb (torch.Tensor): positional encoding, must not be None + for BranchformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. 
+ cnn_cache (torch.Tensor): Convolution cache in cgmlp layer + (#batch=1, size, cache_t2) + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time. + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + stoch_layer_coeff = 1.0 + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + if self.training: + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache, + stoch_layer_coeff) diff --git a/wenet/cli/__init__.py b/wenet/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wenet/cli/hub.py b/wenet/cli/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ca91ad5affa08f256d33d607dc28ff08601374 --- /dev/null +++ b/wenet/cli/hub.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022 Mddct(hamddct@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import requests +import sys +import tarfile +from pathlib import Path +from urllib.request import urlretrieve + +import tqdm + + +def download(url: str, dest: str, only_child=True): + """ download from url to dest + """ + assert os.path.exists(dest) + print('Downloading {} to {}'.format(url, dest)) + + def progress_hook(t): + last_b = [0] + + def update_to(b=1, bsize=1, tsize=None): + if tsize not in (None, -1): + t.total = tsize + displayed = t.update((b - last_b[0]) * bsize) + last_b[0] = b + return displayed + + return update_to + + # *.tar.gz + name = url.split('?')[0].split('/')[-1] + tar_path = os.path.join(dest, name) + with tqdm.tqdm(unit='B', + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=(name)) as t: + urlretrieve(url, + filename=tar_path, + reporthook=progress_hook(t), + data=None) + t.total = t.n + + with tarfile.open(tar_path) as f: + if not only_child: + f.extractall(dest) + else: + for tarinfo in f: + if "/" not in tarinfo.name: + continue + name = os.path.basename(tarinfo.name) + fileobj = f.extractfile(tarinfo) + with open(os.path.join(dest, name), "wb") as writer: + writer.write(fileobj.read()) + + +class Hub(object): + """Hub for wenet pretrain runtime model + """ + # TODO(Mddct): make assets class to support other language + Assets = { + # wenetspeech + "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz", + # gigaspeech + "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz", + # paraformer + "paraformer": "paraformer.tar.gz" + } + + def __init__(self) -> None: + pass + + @staticmethod + def get_model_by_lang(lang: str) -> str: + if lang not in Hub.Assets.keys(): + print('ERROR: Unsupported language {} !!!'.format(lang)) + sys.exit(1) + + # NOTE(Mddct): model_dir structure + # Path.Home()/.wenet + # - chs + # - units.txt + # - final.zip + # - en + # - units.txt + # - final.zip + model = Hub.Assets[lang] + model_dir 
= os.path.join(Path.home(), ".wenet", lang) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + # TODO(Mddct): model metadata + if set(["final.zip", + "units.txt"]).issubset(set(os.listdir(model_dir))): + return model_dir + # If not exist, download + response = requests.get( + "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree" # noqa + ) + model_info = next(data for data in response.json()["Data"] + if data["Key"] == model) + model_url = model_info['Url'] + download(model_url, model_dir, only_child=True) + return model_dir diff --git a/wenet/cli/model.py b/wenet/cli/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bb24bdb3379f58fbb9c6507fedc8da3f2e8fb68f --- /dev/null +++ b/wenet/cli/model.py @@ -0,0 +1,176 @@ +# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi + +from wenet.cli.hub import Hub +from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time, + gen_timestamps_from_peak) +from wenet.utils.file_utils import read_symbol_table +from wenet.transformer.search import (attention_rescoring, + ctc_prefix_beam_search, DecodeResult) +from wenet.utils.context_graph import ContextGraph +from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu + + +class Model: + + def __init__(self, + model_dir: str, + gpu: int = -1, + beam: int = 5, + context_path: str = None, + context_score: float = 6.0, + resample_rate: int = 16000): + model_path = os.path.join(model_dir, 'final.zip') + units_path = os.path.join(model_dir, 'units.txt') + self.model = torch.jit.load(model_path) + self.resample_rate = resample_rate + self.model.eval() + if gpu >= 0: + device = 'cuda:{}'.format(gpu) + else: + device = 'cpu' + self.device = torch.device(device) + self.model.to(device) + self.symbol_table = read_symbol_table(units_path) + self.char_dict = {v: k for k, v in self.symbol_table.items()} + self.beam = beam + if context_path is not None: + self.context_graph = ContextGraph(context_path, + self.symbol_table, + context_score=context_score) + else: + self.context_graph = None + + def compute_feats(self, audio_file: str) -> torch.Tensor: + waveform, sample_rate = torchaudio.load(audio_file, normalize=False) + waveform = waveform.to(torch.float) + if sample_rate != self.resample_rate: + waveform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=self.resample_rate)(waveform) + # NOTE (MengqingCao): complex dtype not supported in torch_npu.abs() now, + # thus, delay placing data on NPU after the calculation of fbank. + # revert me after complex dtype is supported. 
+ if "npu" not in self.device.__str__(): + waveform = waveform.to(self.device) + feats = kaldi.fbank(waveform, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + energy_floor=0.0, + sample_frequency=self.resample_rate) + if "npu" in self.device.__str__(): + feats = feats.to(self.device) + feats = feats.unsqueeze(0) + return feats + + @torch.no_grad() + def _decode(self, + audio_file: str, + tokens_info: bool = False, + label: str = None) -> dict: + feats = self.compute_feats(audio_file) + encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1) + encoder_lens = torch.tensor([encoder_out.size(1)], + dtype=torch.long, + device=encoder_out.device) + ctc_probs = self.model.ctc_activation(encoder_out) + if label is None: + ctc_prefix_results = ctc_prefix_beam_search( + ctc_probs, + encoder_lens, + self.beam, + context_graph=self.context_graph) + else: # force align mode, construct ctc prefix result from alignment + label_t = self.tokenize(label) + alignment = force_align(ctc_probs.squeeze(0), + torch.tensor(label_t, dtype=torch.long)) + peaks = gen_ctc_peak_time(alignment) + ctc_prefix_results = [ + DecodeResult(tokens=label_t, + score=0.0, + times=peaks, + nbest=[label_t], + nbest_scores=[0.0], + nbest_times=[peaks]) + ] + rescoring_results = attention_rescoring(self.model, ctc_prefix_results, + encoder_out, encoder_lens, 0.3, + 0.5) + res = rescoring_results[0] + result = {} + result['text'] = ''.join([self.char_dict[x] for x in res.tokens]) + result['confidence'] = res.confidence + + if tokens_info: + frame_rate = self.model.subsampling_rate( + ) * 0.01 # 0.01 seconds per frame + max_duration = encoder_out.size(1) * frame_rate + times = gen_timestamps_from_peak(res.times, max_duration, + frame_rate, 1.0) + tokens_info = [] + for i, x in enumerate(res.tokens): + tokens_info.append({ + 'token': self.char_dict[x], + 'start': round(times[i][0], 3), + 'end': round(times[i][1], 3), + 'confidence': round(res.tokens_confidence[i], 2) + }) + result['tokens'] = tokens_info + return result + + def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: + return self._decode(audio_file, tokens_info) + + def tokenize(self, label: str): + # TODO(Binbin Zhang): Support BPE + tokens = [] + for c in label: + if c == ' ': + c = "▁" + tokens.append(c) + token_list = [] + for c in tokens: + if c in self.symbol_table: + token_list.append(self.symbol_table[c]) + elif '' in self.symbol_table: + token_list.append(self.symbol_table['']) + return token_list + + def align(self, audio_file: str, label: str) -> dict: + return self._decode(audio_file, True, label) + + +def load_model(language: str = None, + model_dir: str = None, + gpu: int = -1, + beam: int = 5, + context_path: str = None, + context_score: float = 6.0, + device: str = "cpu") -> Model: + if model_dir is None: + model_dir = Hub.get_model_by_lang(language) + + if gpu != -1: + # remain the original usage of gpu + device = "cuda" + model = Model(model_dir, gpu, beam, context_path, context_score) + model.device = torch.device(device) + model.model.to(device) + return model diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f834ab25446b898c777ee821bbe5c03b5f3fbe --- /dev/null +++ b/wenet/cli/paraformer_model.py @@ -0,0 +1,82 @@ +import os + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi + +from wenet.cli.hub import Hub +from wenet.paraformer.search import (gen_timestamps_from_peak, + 
paraformer_greedy_search) +from wenet.text.paraformer_tokenizer import ParaformerTokenizer +from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu + + +class Paraformer: + + def __init__(self, model_dir: str, resample_rate: int = 16000) -> None: + + model_path = os.path.join(model_dir, 'final.zip') + units_path = os.path.join(model_dir, 'units.txt') + self.model = torch.jit.load(model_path) + self.resample_rate = resample_rate + self.device = torch.device("cpu") + self.tokenizer = ParaformerTokenizer(symbol_table=units_path) + + def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: + waveform, sample_rate = torchaudio.load(audio_file, normalize=False) + waveform = waveform.to(torch.float).to(self.device) + if sample_rate != self.resample_rate: + waveform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=self.resample_rate)(waveform) + feats = kaldi.fbank(waveform, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + energy_floor=0.0, + sample_frequency=self.resample_rate, + window_type="hamming") + feats = feats.unsqueeze(0) + feats_lens = torch.tensor([feats.size(1)], + dtype=torch.int64, + device=feats.device) + + decoder_out, token_num, tp_alphas = self.model.forward_paraformer( + feats, feats_lens) + cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num) + res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0] + result = {} + result['confidence'] = res.confidence + result['text'] = self.tokenizer.detokenize(res.tokens)[0] + if tokens_info: + tokens_info = [] + times = gen_timestamps_from_peak(res.times, + num_frames=tp_alphas.size(1), + frame_rate=0.02) + + for i, x in enumerate(res.tokens): + tokens_info.append({ + 'token': self.tokenizer.char_dict[x], + 'start': round(times[i][0], 3), + 'end': round(times[i][1], 3), + 'confidence': round(res.tokens_confidence[i], 2) + }) + result['tokens'] = tokens_info + + return result + + def align(self, audio_file: str, label: str) -> dict: + raise NotImplementedError("Align is currently not supported") + + +def load_model(model_dir: str = None, + gpu: int = -1, + device: str = "cpu") -> Paraformer: + if model_dir is None: + model_dir = Hub.get_model_by_lang('paraformer') + if gpu != -1: + # remain the original usage of gpu + device = "cuda" + paraformer = Paraformer(model_dir) + paraformer.device = torch.device(device) + paraformer.model.to(device) + return paraformer diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py new file mode 100644 index 0000000000000000000000000000000000000000..28bf27919273f06a343f3faa824dc710339a6077 --- /dev/null +++ b/wenet/cli/transcribe.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
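For reference, a minimal sketch of calling the Paraformer wrapper above directly, rather than through the command-line entry point defined below (the wav path is illustrative; with no model_dir, load_model fetches the pretrained bundle via Hub):

    from wenet.cli.paraformer_model import load_model

    model = load_model()  # downloads final.zip and units.txt on first use
    result = model.transcribe('test_16k.wav', tokens_info=True)
    print(result['text'], result['confidence'])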
+ +import argparse + +from wenet.cli.paraformer_model import load_model as load_paraformer +from wenet.cli.model import load_model + + +def get_args(): + parser = argparse.ArgumentParser(description='') + parser.add_argument('audio_file', help='audio file to transcribe') + parser.add_argument('-l', + '--language', + choices=[ + 'chinese', + 'english', + ], + default='chinese', + help='language type') + parser.add_argument('-m', + '--model_dir', + default=None, + help='specify your own model dir') + parser.add_argument('-g', + '--gpu', + type=int, + default='-1', + help='gpu id to decode, default is cpu.') + parser.add_argument('--device', + type=str, + default='cpu', + choices=["cpu", "npu", "cuda"], + help='accelerator to use') + parser.add_argument('-t', + '--show_tokens_info', + action='store_true', + help='whether to output token(word) level information' + ', such times/confidence') + parser.add_argument('--align', + action='store_true', + help='force align the input audio and transcript') + parser.add_argument('--label', type=str, help='the input label to align') + parser.add_argument('--paraformer', + action='store_true', + help='whether to use the best chinese model') + parser.add_argument('--beam', type=int, default=5, help="beam size") + parser.add_argument('--context_path', + type=str, + default=None, + help='context list file') + parser.add_argument('--context_score', + type=float, + default=6.0, + help='context score') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + if args.paraformer: + model = load_paraformer(args.model_dir, args.gpu, args.device) + else: + model = load_model(args.language, args.model_dir, args.gpu, args.beam, + args.context_path, args.context_score, args.device) + if args.align: + result = model.align(args.audio_file, args.label) + else: + result = model.transcribe(args.audio_file, args.show_tokens_info) + print(result) + + +if __name__ == "__main__": + main() diff --git a/wenet/ctl_model/asr_model_ctl.py b/wenet/ctl_model/asr_model_ctl.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9bc810a7c010432d0f70c423d14c59c8bbcd79 --- /dev/null +++ b/wenet/ctl_model/asr_model_ctl.py @@ -0,0 +1,277 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2023 NetEase Inc. (authors: Yuting Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from ESPnet(https://github.com/espnet/espnet) and +# fairseq(https://github.com/facebookresearch/fairseq) + +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import TransformerDecoder +from wenet.ctl_model.encoder import TransformerEncoder +from wenet.transformer.asr_model import ASRModel +from wenet.utils.common import IGNORE_ID + + +class CTLModel(ASRModel): + """ + Implementation of Interspeecch 2023 paper: + 'Enhancing the Unified Streaming and Non-streaming Model + with Contrastive Learning' + https://arxiv.org/abs/2306.00755 + """ + + def __init__( + self, + vocab_size: int, + encoder: TransformerEncoder, + decoder: TransformerDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + ignore_id: int = IGNORE_ID, + reverse_weight: float = 0.0, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + logit_temp: float = 0.1, + n_negatives: int = 0, + ctl_weight: float = 1, + special_tokens: dict = None, + ): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + super().__init__(vocab_size, + encoder, + decoder, + ctc, + ctc_weight, + ignore_id, + reverse_weight, + lsm_weight, + length_normalized_loss, + special_tokens=special_tokens) + + # For CTL Loss + self.n_negatives = n_negatives + self.ctl_weight = ctl_weight + self.logit_temp = logit_temp + + @torch.jit.unused + def forward( + self, + batch: dict, + device: torch.device, + ) -> Dict[str, Optional[torch.Tensor]]: + + speech = batch['feats'].to(device) + speech_lengths = batch['feats_lengths'].to(device) + text = batch['target'].to(device) + text_lengths = batch['target_lengths'].to(device) + loss_full, encoder_out_full, _, _ = self.forward_full( + speech, speech_lengths, text, text_lengths) + loss_chunk, encoder_out, lens_chunk, encoder_mask = self.forward_chunk( + speech, speech_lengths, text, text_lengths) + + ctl_loss = 0.0 + if self.ctl_weight > 0 and self.n_negatives > 0: + num = encoder_out_full.size(1) + targets = encoder_out_full + src = encoder_out + negs, negs_idxs = self.sample_negatives(targets, + targets.size(1), + speech_lengths=lens_chunk) + ctl_loss = self.CTL(src, targets, negs, encoder_mask) + + loss = loss_full + loss_chunk + self.ctl_weight * ctl_loss + return { + "loss": loss, + "loss_full": loss_full, + "loss_chunk": loss_chunk, + "loss_ctl": ctl_loss + } + + def forward_full( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ): + """Full context mode + Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + encoder_out, encoder_mask = self.encoder.forward_full( + speech, speech_lengths) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + + # 2a. Attention-decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + loss_att = None + + # 2b. 
CTC branch + if self.ctc_weight != 0.0: + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + else: + loss_ctc = None + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc[0] + (1 - + self.ctc_weight) * loss_att + return loss, encoder_out, encoder_out_lens, encoder_mask + + def forward_chunk( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ): + """Chunk-based context mode + Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + + # 2a. Attention-decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + loss_att = None + + # 2b. CTC branch + if self.ctc_weight != 0.0: + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + else: + loss_ctc = None + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc[0] + (1 - + self.ctc_weight) * loss_att + return loss, encoder_out, encoder_out_lens, encoder_mask + + def sample_negatives(self, y, num, padding_count=0, speech_lengths=None): + if self.n_negatives == 0: + return y.new(0) + bsz, tsz, fsz = y.shape + y = y.reshape(-1, fsz) # BTC => (BxT)C + + # FIXME: what happens if padding_count is specified? 
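+        # `high` is the per-utterance stride (tsz minus any padding_count) used
+        # below to offset row-local frame indices into the flattened (B*T, C)
+        # tensor, so negatives are drawn from other frames of the same
+        # utterance; the `+= 1` shift (taken when speech_lengths is not given)
+        # keeps the positive frame itself from being sampled.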
+ high = tsz - (padding_count or 0) + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = (torch.arange(num).unsqueeze(-1).expand( + -1, self.n_negatives).flatten()) + if speech_lengths is not None: + neg_idxs = [ + torch.randint(low=0, + high=speech_lengths[i].item() - 1, + size=(1, self.n_negatives * tsz)) + for i in range(len(speech_lengths)) + ] + neg_idxs = torch.cat(neg_idxs).reshape( + bsz, self.n_negatives * tsz) + else: + neg_idxs = torch.randint(low=0, + high=num - 1, + size=(bsz, + self.n_negatives * tsz)) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high) + + negs = y[neg_idxs.view(-1)] + negs = negs.contiguous().view(bsz, num, self.n_negatives, + fsz).permute(2, 0, 1, 3) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) + logits = logits / self.logit_temp + logits = logits.type_as(x) + + if neg_is_pos.any(): + if not hasattr(self, "_inftensor"): + self._inftensor = float("-inf") + # logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) + logits[1:][neg_is_pos] = self._inftensor + logits = logits.transpose(0, 2) + logits = logits.transpose(0, 1) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def CTL(self, x, y, negs, mask=None): + # Step1: compute cosine similarity, shape [B*T, n_negatives+1] + logits = self.compute_preds(x, y, negs) + + # Step2: target shape [B*T] + target = x.new_zeros(x.size(0) * x.size(1), dtype=torch.long) + + # Step3: compute CTL loss + if mask is not None: + normalize_length = mask.sum() + bz, sz = mask.size(0), mask.size(-1) + mask = mask.squeeze(1).reshape(bz * sz).eq(0) + ce = F.cross_entropy(logits, target, reduction='none') + loss = ce.masked_fill(mask, 0).sum() / normalize_length + else: + loss = F.cross_entropy(logits, target) + + return loss diff --git a/wenet/ctl_model/encoder.py b/wenet/ctl_model/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9aa18b7048a922e0bd9725cfe867083a5bf5b26e --- /dev/null +++ b/wenet/ctl_model/encoder.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2023 NetEase Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from ESPnet(https://github.com/espnet/espnet) +"""Encoder definition.""" +from typing import Optional, Tuple + +import torch + +from wenet.utils.mask import make_pad_mask +from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder + + +class DualTransformerEncoder(TransformerEncoder): + """Transformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + activation_type: str = "relu", + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + selfattention_layer_type: str = "selfattn", + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, + ): + """ Construct DualTransformerEncoder + Support both the full context mode and the streaming mode separately + """ + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, query_bias, key_bias, + value_bias, activation_type, gradient_checkpointing, + use_sdpa, layer_norm_type, norm_eps, n_kv_head, + head_dim, selfattention_layer_type, mlp_type, + mlp_bias, n_expert, n_expert_activated) + + def forward_full( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + for layer in self.encoders: + xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks + + +class DualConformerEncoder(ConformerEncoder): + """Conformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + conv_bias: bool = True, + gradient_checkpointing: bool = 
False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, + ): + """ Construct DualConformerEncoder + Support both the full context mode and the streaming mode separately + """ + super().__init__( + input_size, output_size, attention_heads, linear_units, num_blocks, + dropout_rate, positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + static_chunk_size, use_dynamic_chunk, global_cmvn, + use_dynamic_left_chunk, positionwise_conv_kernel_size, + macaron_style, selfattention_layer_type, activation_type, + use_cnn_module, cnn_module_kernel, causal, cnn_module_norm, + query_bias, key_bias, value_bias, conv_bias, + gradient_checkpointing, use_sdpa, layer_norm_type, norm_eps, + n_kv_head, head_dim, mlp_type, mlp_bias, n_expert, + n_expert_activated) + + def forward_full( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + for layer in self.encoders: + xs, masks, _, _ = layer(xs, masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks diff --git a/wenet/dataset/__init__.py b/wenet/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py new file mode 100644 index 0000000000000000000000000000000000000000..54127a8214f0325b8821356d7ec3c73b3a4e5741 --- /dev/null +++ b/wenet/dataset/datapipes.py @@ -0,0 +1,470 @@ +# Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import collections +from collections.abc import Callable +import copy +import sys +import tarfile +import logging +from typing import List, Optional +import numpy as np +import torch +from torch.utils.data import IterDataPipe, functional_datapipe +from torch.utils.data import datapipes +from torch.utils.data.datapipes.iter import Mapper +from torch.utils.data.datapipes.iter.sharding import ( + SHARDING_PRIORITIES, ShardingFilterIterDataPipe) +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn + +from wenet.dataset.processor import parse_url + + +@functional_datapipe("map_ignore_error") +class MapperIgnoreErrorDataPipe(Mapper): + + def __init__(self, + dataset: IterDataPipe, + fn: Callable, + input_col=None, + output_col=None, + log_error: bool = True) -> None: + super().__init__(dataset, fn, input_col, output_col) + self._iter = None + self.log_error = log_error + + def __iter__(self): + if self._iter is None: + self._iter = iter(self.datapipe) + + while True: + try: + elem = next(self._iter) + yield self._apply_fn(elem) + except StopIteration: + self._iter = None + return + except Exception as ex: + if self.log_error: + logging.warning(str(ex)) + + +@functional_datapipe('bucket_by_sequence_length') +class BucketBySequenceLengthDataPipe(IterDataPipe): + + def __init__( + self, + dataset: IterDataPipe, + elem_length_func, + bucket_boundaries: List[int], + bucket_batch_sizes: List[int], + wrapper_class=None, + ) -> None: + super().__init__() + _check_unpickable_fn(elem_length_func) + assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1 + self.bucket_batch_sizes = bucket_batch_sizes + self.bucket_boundaries = bucket_boundaries + [sys.maxsize] + self.elem_length_func = elem_length_func + + self._group_dp = GroupByWindowDataPipe(dataset, + self._element_to_bucket_id, + self._window_size_func, + wrapper_class=wrapper_class) + + def __iter__(self): + yield from self._group_dp + + def _element_to_bucket_id(self, elem): + seq_len = self.elem_length_func(elem) + bucket_id = 0 + for (i, b) in enumerate(self.bucket_boundaries): + if seq_len < b: + bucket_id = i + break + return bucket_id + + def _window_size_func(self, bucket_id): + return self.bucket_batch_sizes[bucket_id] + + +@functional_datapipe("group_by_window") +class GroupByWindowDataPipe(datapipes.iter.Grouper): + + def __init__( + self, + dataset: IterDataPipe, + key_func, + window_size_func, + wrapper_class=None, + ): + super().__init__(dataset, + key_func, + keep_key=False, + group_size=None, + drop_remaining=False) + _check_unpickable_fn(window_size_func) + self.dp = dataset + self.window_size_func = window_size_func + if wrapper_class is not None: + _check_unpickable_fn(wrapper_class) + del self.wrapper_class + self.wrapper_class = wrapper_class + + def __iter__(self): + for x in self.datapipe: + key = self.group_key_fn(x) + + self.buffer_elements[key].append(x) + self.curr_buffer_size += 1 + + group_size = self.window_size_func(key) + if group_size == len(self.buffer_elements[key]): + result = self.wrapper_class(self.buffer_elements[key]) + yield result + self.curr_buffer_size -= len(self.buffer_elements[key]) + del self.buffer_elements[key] + + if self.curr_buffer_size == self.max_buffer_size: + result_to_yield = self._remove_biggest_key() + if result_to_yield is not None: + result = self.wrapper_class(result_to_yield) + yield result + + for key in tuple(self.buffer_elements.keys()): + result = self.wrapper_class(self.buffer_elements.pop(key)) + self.curr_buffer_size -= len(result) + yield result + + 
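A quick sketch of how the two pipes above compose (the 'feat' key, the concrete numbers, and padding_fn are illustrative, not part of the API): with bucket_boundaries=[500, 1000] and bucket_batch_sizes=[32, 16, 8], an element of length 700 falls into bucket 1, and the grouper emits a batch as soon as 16 elements of that bucket have been buffered:

    dp = source.bucket_by_sequence_length(
        elem_length_func=lambda sample: sample['feat'].size(0),
        bucket_boundaries=[500, 1000],
        bucket_batch_sizes=[32, 16, 8],
        wrapper_class=padding_fn)  # padding_fn: any callable that collates a list of samples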
+@functional_datapipe("sort") +class SortDataPipe(IterDataPipe): + + def __init__(self, + dataset: IterDataPipe, + buffer_size: int = 500, + key_func=None, + reverse=False) -> None: + if key_func is not None: + _check_unpickable_fn(key_func) + self.buffer_size = buffer_size + super().__init__() + self.dp = dataset + self._buffer = [] + self.key_func = key_func + self.reverse = reverse + + def __iter__(self): + for elem in self.dp: + self._buffer.append(elem) + if len(self._buffer) >= self.buffer_size: + self._buffer.sort(key=self.key_func, reverse=self.reverse) + for x in self._buffer: + yield x + del self._buffer + self._buffer = [] + # The sample left over + self._buffer.sort(key=self.key_func, reverse=self.reverse) + for x in self._buffer: + yield x + del self._buffer + self._buffer = [] + + +@functional_datapipe("dynamic_batch") +class DynamicBatchDataPipe(IterDataPipe): + + def __init__(self, dataset: IterDataPipe, window_class, + wrapper_class) -> None: + _check_unpickable_fn(window_class) + _check_unpickable_fn(wrapper_class) + super().__init__() + self.dp = dataset + assert window_class is not None + assert wrapper_class is not None + self.window_class = window_class + self._buffer = [] + self._wrappr_class = wrapper_class + + def __iter__(self): + for elem in self.dp: + if not self.window_class(elem, len(self._buffer)): + self._buffer.append(elem) + else: + if len(self._buffer) > 0: + yield self._wrappr_class(self._buffer) + del self._buffer + self._buffer = [elem] + if len(self._buffer) > 0: + yield self._wrappr_class(self._buffer) + del self._buffer + self._buffer = [] + + +@functional_datapipe("prefetch") +class PrefetchDataPipe(IterDataPipe): + """Performs prefetching""" + + def __init__( + self, + dataset: IterDataPipe, + buffer_size: int = 500, + ): + # TODO(Mddct): support multiprocessing pool with shared-memory to + # prefetch + super().__init__() + self.dp = dataset + self._iter = None + self._prefetch_buffer_size = buffer_size + self._buffer = None + if self._prefetch_buffer_size > 0: + self._buffer = collections.deque(maxlen=self._prefetch_buffer_size) + + def __iter__(self): + if self._prefetch_buffer_size > 0: + if self._iter is None: + self._iter = iter(self.dp) + assert self._buffer is not None + + while True: + if len(self._buffer) <= self._prefetch_buffer_size // 2: + while len(self._buffer) < self._prefetch_buffer_size: + try: + self._buffer.append(next(self._iter)) + except StopIteration: + if len(self._buffer) != 0: + while len(self._buffer) > 0: + yield self._buffer.popleft() + self._iter = None + return + while len(self._buffer) > self._prefetch_buffer_size // 2: + elem = self._buffer.popleft() + yield elem + + else: + yield from self.dp + + +@functional_datapipe("repeat") +class RepeatDatapipe(IterDataPipe): + + def __init__(self, dataset: IterDataPipe, count: int = -1): + super().__init__() + self.dp = dataset + self.count = count + + def __iter__(self): + if self.count == 1: + yield from self.dp + return + i = 0 + while self.count < 0 or i < self.count: + for elem in self.dp: + new_elem = copy.copy(elem) + yield new_elem + i += 1 + + +@functional_datapipe("shard") +class ShardDataPipe(ShardingFilterIterDataPipe): + + def __init__(self, dataset: IterDataPipe, partition: bool = False): + super().__init__(dataset, None) + self.partition = partition + self.dp = dataset + + def apply_sharding(self, num_of_instances: int, instance_id: int, + sharding_group: SHARDING_PRIORITIES): + if self.partition: + return super().apply_sharding(num_of_instances, instance_id, 
+ sharding_group) + else: + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + info = torch.utils.data.get_worker_info() + if info is None: + self.num_of_instances = 1 + self.instance_id = 0 + else: + n_workers_per_device = info.num_workers + self.num_of_instances = n_workers_per_device + self.instance_id = info.id + + +@functional_datapipe("interleave") +class InterlaveDataPipe(IterDataPipe): + + def __init__( + self, + source_datapipes: List[IterDataPipe], + weights: Optional[List[float]] = None, + seed=2027, + ): + super().__init__() + self.rng = np.random.default_rng(seed) + self.source_datapipes = source_datapipes + self.weights = weights + if weights is None: + self.weights = [1 / len(self.source_datapipes)] * len( + self.source_datapipes) + else: + self.weights = [weight / sum(weights) for weight in weights] + self.iters = None + + def __iter__(self): + weights = copy.deepcopy(self.weights) + exhausted = len(self.source_datapipes) * [False] + if self.iters is None: + self.iters = [(i, iter(d)) + for i, d in enumerate(self.source_datapipes)] + while True: + # TODO(Mddct): rng + index_iter = self.rng.choice(self.iters, p=weights) + i, ite = index_iter + try: + elem = next(ite) + yield elem + except StopIteration: + weights[i] = 0. + exhausted[i] = True + if all(exhausted): + return + weights = [weight / sum(weights) for weight in weights] + + +class TextLineDataPipe(IterDataPipe): + """ Streamming Text line + """ + + def __init__(self, filenames, mode='r'): + super().__init__() + _dp = datapipes.iter.FileLister(filenames) + _dp = datapipes.iter.FileOpener(_dp, mode=mode) + self.dp = _dp + + def __iter__(self): + for fname, stream in self.dp: + for line in stream: + line = line.strip('\n') + yield {"file_name": fname, "line": line} + stream.close() + + +@functional_datapipe("tar_file_and_group") +class TarsDataPipe(IterDataPipe): + """ Decode wenet's tar , yield {'txt': "...", "raw": "..."} + """ + + def __init__(self, dataset: IterDataPipe) -> None: + super().__init__() + self.dp = dataset + + def __iter__(self): + from wenet.dataset.processor import AUDIO_FORMAT_SETS + for sample in self.dp: + assert 'file_name' in sample + assert 'line' in sample + assert 'stream' in sample + try: + with tarfile.open(fileobj=sample['stream'], + mode="r:*") as stream: + prev_prefix = None + example = { + 'file_name': sample['file_name'], + 'tar_file_name': sample['line'] + } + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = { + 'file_name': sample['file_name'], + 'tar_file_name': sample['line'] + } + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode( + 'utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + example['wav'] = file_obj.read() + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning( + 'error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + except Exception as ex: + msg = 'In tar_file_and_group: {} when processing {}'.format( + ex, sample['line']) + logging.warning(msg) + finally: + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() 
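The grouping above assumes every entry in a shard tar is named <key>.<ext>, so a transcript and its audio share the same prefix; a sketch of the expected layout (file and key names are illustrative):

    shard_000000.tar
        BAC009S0002W0122.txt   -> example['txt']  (utf-8 transcript)
        BAC009S0002W0122.wav   -> example['wav']  (raw audio bytes)
        BAC009S0002W0123.txt
        BAC009S0002W0123.wav

Each prefix is then yielded as one dict carrying 'key', 'txt' and 'wav', plus the file_name/tar_file_name bookkeeping fields.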
+ + +class WenetRawDatasetSource(IterDataPipe): + + def __init__(self, + filenames: str, + prefetch: int = 500, + partition: bool = True, + shuffle: bool = False, + shuffle_size: int = 10000, + cycle: int = 1) -> None: + super().__init__() + self.dp = TextLineDataPipe(filenames) + if shuffle: + self.dp = self.dp.shuffle(buffer_size=shuffle_size) + self.dp = self.dp.repeat(cycle).prefetch(prefetch) + self.dp = self.dp.shard(partition) + + def __iter__(self): + for d in self.dp: + yield d + + +class WenetTarShardDatasetSource(IterDataPipe): + + def __init__(self, + filenames: str, + prefetch: int = 500, + partition: bool = True, + shuffle: bool = False, + shuffle_size: int = 10000, + cycle: int = 1) -> None: + super().__init__() + self.dp = TextLineDataPipe(filenames) + if shuffle: + self.dp = self.dp.shuffle(buffer_size=shuffle_size) + self.dp = self.dp.repeat(cycle) + self.dp = self.dp.shard(partition).map_ignore_error( + parse_url).tar_file_and_group().prefetch(prefetch) + + def __iter__(self): + for d in self.dp: + yield d diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fed31991e62232a0da741614129f66cf2d369f0f --- /dev/null +++ b/wenet/dataset/dataset.py @@ -0,0 +1,234 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import wenet.dataset.deprecated.processor as processor +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.utils.file_utils import read_lists + + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. 
+ """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True, split_num=1): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + self.split_num = split_num + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def split_data(self, total_num): + data = list(range(total_num)) + sub_epoch = self.epoch + 1 + full_epoch = sub_epoch // self.split_num + num_per_sub_epochs = total_num // self.split_num + random.Random(full_epoch).shuffle(data) + + split_index = sub_epoch - full_epoch * self.split_num + begin = split_index * num_per_sub_epochs + end = (begin + num_per_sub_epochs + if (split_index + 1) < self.split_num else + total_num) + + # print(f'begin: {begin}, end: {end}, world_size: {self.world_size}') + return data[begin:end] + + def sample(self, data, split_num=1): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + if self.split_num == 1: + data = list(range(len(data))) + else: + data = self.split_data(len(data)) + # TODO(Binbin Zhang): fix this + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + # print(f'num dataset: {len(data)}') + data = data[self.worker_id::self.num_workers] + self.epoch += 1 + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True, split_num=1): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition, split_num) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_type, + data_list_file, + tokenizer: BaseTokenizer, + conf, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+
+    Args:
+        data_type(str): raw/shard/shard_full_data
+        bpe_model(str): model for english bpe part
+        partition(bool): whether to do data partition in terms of rank
+    """
+    assert data_type in ['raw', 'shard', 'shard_full_data']
+    lists = read_lists(data_list_file)
+    shuffle = conf.get('shuffle', True)
+    split_num = conf.get('split_num', 1)
+    dataset = DataList(lists, shuffle=shuffle, partition=partition, split_num=split_num)
+    if data_type == 'shard':
+        dataset = Processor(dataset, processor.url_opener)
+        dataset = Processor(dataset, processor.tar_file_and_group)
+    elif data_type == 'shard_full_data':
+        dataset = Processor(dataset, processor.url_opener)
+        dataset = Processor(dataset, processor.tar_file_and_group_full_data)
+    else:
+        dataset = Processor(dataset, processor.parse_raw)
+
+    speaker_conf = conf.get('speaker_conf', None)
+    if speaker_conf is not None:
+        dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)
+
+    if conf.get('eod_id', None) is not None:
+        tokenizer.eod_id = conf['eod_id']
+    # prompt dict: task prompt templates consumed by processor.tokenize
+    import yaml
+    with open('conf/prompt_stage4.yaml', 'rt', encoding='utf-8') as fin:
+        global_prompt_dict = yaml.load(fin, Loader=yaml.FullLoader)
+    dataset = Processor(dataset, processor.tokenize, tokenizer,
+                        global_prompt_dict=global_prompt_dict)
+    filter_conf = conf.get('filter_conf', {})
+    dataset = Processor(dataset, processor.filter, **filter_conf)
+
+    resample_conf = conf.get('resample_conf', {})
+    dataset = Processor(dataset, processor.resample, **resample_conf)
+
+    speed_perturb = conf.get('speed_perturb', False)
+    if speed_perturb:
+        dataset = Processor(dataset, processor.speed_perturb)
+
+    feats_type = conf.get('feats_type', 'fbank')
+    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
+    if feats_type == 'fbank':
+        fbank_conf = conf.get('fbank_conf', {})
+        dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
+    elif feats_type == 'mfcc':
+        mfcc_conf = conf.get('mfcc_conf', {})
+        dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
+    elif feats_type == 'log_mel_spectrogram':
+        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
+        dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
+                            **log_mel_spectrogram_conf)
+
+    spec_aug = conf.get('spec_aug', True)
+    spec_sub = conf.get('spec_sub', False)
+    spec_trim = conf.get('spec_trim', False)
+    if spec_aug:
+        spec_aug_conf = conf.get('spec_aug_conf', {})
+        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
+    if spec_sub:
+        spec_sub_conf = conf.get('spec_sub_conf', {})
+        dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
+    if spec_trim:
+        spec_trim_conf = conf.get('spec_trim_conf', {})
+        dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)
+
+    if shuffle:
+        shuffle_conf = conf.get('shuffle_conf', {})
+        dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
+
+    sort = conf.get('sort', True)
+    if sort:
+        sort_conf = conf.get('sort_conf', {})
+        dataset = Processor(dataset, processor.sort, **sort_conf)
+
+    batch_conf = conf.get('batch_conf', {})
+    dataset = Processor(dataset, processor.batch, **batch_conf)
+    dataset = Processor(dataset, processor.padding)
+    return dataset
diff --git a/wenet/dataset/deprecated/dataset.py b/wenet/dataset/deprecated/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ce51612cbe6968cb5879be13404cc1115e04710
--- /dev/null
+++ b/wenet/dataset/deprecated/dataset.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2021 Mobvoi Inc.
(authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import wenet.dataset.deprecated.processor as processor +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.utils.file_utils import read_lists + + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. + """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # TODO(Binbin Zhang): fix this + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_type, + data_list_file, + tokenizer: BaseTokenizer, + conf, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
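+
+    When partition is True, each DDP rank keeps an interleaved slice of the
+    (optionally shuffled) shard list, and each dataloader worker then takes
+    its own slice of that (see DistributedSampler.sample above); roughly:
+
+        data = data[rank::world_size][worker_id::num_workers]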
+ + Args: + data_type(str): raw/shard + bpe_model(str): model for english bpe part + partition(bool): whether to do data partition in terms of rank + """ + assert data_type in ['raw', 'shard'] + lists = read_lists(data_list_file) + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + if data_type == 'shard': + dataset = Processor(dataset, processor.url_opener) + dataset = Processor(dataset, processor.tar_file_and_group) + else: + dataset = Processor(dataset, processor.parse_raw) + + speaker_conf = conf.get('speaker_conf', None) + if speaker_conf is not None: + dataset = Processor(dataset, processor.parse_speaker, **speaker_conf) + + dataset = Processor(dataset, processor.tokenize, tokenizer) + filter_conf = conf.get('filter_conf', {}) + dataset = Processor(dataset, processor.filter, **filter_conf) + + resample_conf = conf.get('resample_conf', {}) + dataset = Processor(dataset, processor.resample, **resample_conf) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = Processor(dataset, processor.speed_perturb) + + feats_type = conf.get('feats_type', 'fbank') + assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram'] + if feats_type == 'fbank': + fbank_conf = conf.get('fbank_conf', {}) + dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) + elif feats_type == 'mfcc': + mfcc_conf = conf.get('mfcc_conf', {}) + dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf) + elif feats_type == 'log_mel_spectrogram': + log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {}) + dataset = Processor(dataset, processor.compute_log_mel_spectrogram, + **log_mel_spectrogram_conf) + + spec_aug = conf.get('spec_aug', True) + spec_sub = conf.get('spec_sub', False) + spec_trim = conf.get('spec_trim', False) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + if spec_sub: + spec_sub_conf = conf.get('spec_sub_conf', {}) + dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf) + if spec_trim: + spec_trim_conf = conf.get('spec_trim_conf', {}) + dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf) + + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = Processor(dataset, processor.sort, **sort_conf) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset diff --git a/wenet/dataset/deprecated/processor.py b/wenet/dataset/deprecated/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a0c671db14abe35cdae4e590c024c46fc203b2 --- /dev/null +++ b/wenet/dataset/deprecated/processor.py @@ -0,0 +1,1023 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import librosa +import logging +import json +import random +import tarfile +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +import torch.nn.functional as F +from gxl_ai_utils.utils import utils_file +from torch.nn.utils.rnn import pad_sequence +from wenet.text.base_tokenizer import BaseTokenizer + +# torchaudio.utils.sox_utils.set_buffer_size(16500) +torchaudio.set_audio_backend("soundfile") + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def url_opener(data): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + # TODO(Binbin Zhang): support HTTP + url = sample['src'] + try: + pr = urlparse(url) + # local file + if pr.scheme == '' or pr.scheme == 'file': + stream = open(url, 'rb') + # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP + else: + cmd = f'wget -q -O - {url}' + process = Popen(cmd, shell=True, stdout=PIPE) + sample.update(process=process) + stream = process.stdout + sample.update(stream=stream) + yield sample + except Exception as ex: + logging.warning('Failed to open {}'.format(url)) + + +def tar_file_and_group(data): + """ Expand a stream of open tar files into a stream of tar file contents. + And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'stream' in sample + stream = None + try: + stream = tarfile.open(fileobj=sample['stream'], mode="r:*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode( + 'utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = torchaudio.load(file_obj) + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + except Exception as ex: + logging.warning( + 'In tar_file_and_group: {} when processing {}'.format( + ex, sample['src'])) + finally: + if stream is not None: + stream.close() + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() + + +def tar_file_and_group_full_data(data): + """ Expand a stream of open tar files into a stream of tar file contents. 
+        and groups files with the same prefix
+
+    Args:
+        data: Iterable[{src, stream}]
+
+    Returns:
+        Iterable[{key, wav, txt, sample_rate}]
+    """
+    for sample in data:
+        assert 'stream' in sample
+        stream = None
+        try:
+            stream = tarfile.open(fileobj=sample['stream'], mode="r:*")
+            prev_prefix = None
+            example = {}
+            valid = True
+            for tarinfo in stream:
+                name = tarinfo.name
+                pos = name.rfind('.')
+                assert pos > 0
+                prefix, postfix = name[:pos], name[pos + 1:]
+                if prev_prefix is not None and prefix != prev_prefix:
+                    example['key'] = prev_prefix
+                    if valid:
+                        # assert 'txt' in example
+                        if 'txt' not in example:
+                            example['txt'] = ''
+                        yield example
+                    example = {}
+                    valid = True
+                with stream.extractfile(tarinfo) as file_obj:
+                    try:
+                        if postfix == 'txt':
+                            example['txt'] = file_obj.read().decode(
+                                'utf8').strip()
+                        elif postfix == 'lang':
+                            example['lang'] = file_obj.read().decode(
+                                'utf8').strip()
+                        elif postfix == 'speaker':
+                            try:
+                                example['speaker'] = file_obj.read().decode(
+                                    'utf8').strip()
+                            except Exception as ex:
+                                example['speaker'] = "none"
+                        elif postfix == 'emotion':
+                            example['emotion'] = file_obj.read().decode(
+                                'utf8').strip()
+                        elif postfix == 'gender':
+                            example['gender'] = file_obj.read().decode(
+                                'utf8').strip()
+                        elif postfix == 'task':
+                            example['task'] = file_obj.read().decode(
+                                'utf8').strip()
+                        elif postfix == 'speech_token':
+                            example['speech_token'] = file_obj.read()
+                        elif postfix == 'duration':
+                            duration_str = file_obj.read().decode(
+                                'utf8').strip()
+                            try:
+                                duration_float = float(duration_str)
+                                example['duration'] = duration_float
+                            except Exception as ex:
+                                logging.warning(
+                                    f'failed to parse duration {duration_str}')
+                                example['duration'] = 0
+
+                        elif postfix in AUDIO_FORMAT_SETS:
+                            waveform, sample_rate = torchaudio.load(file_obj)
+                            # Check the number of audio channels
+                            num_channels = waveform.shape[0]
+                            # If the audio is multi-channel, average it down to mono
+                            if num_channels > 1:
+                                waveform = torch.mean(waveform, dim=0, keepdim=True)
+                            example['wav'] = waveform
+                            example['sample_rate'] = sample_rate
+                        else:
+                            example[postfix] = file_obj.read()
+                    except Exception as ex:
+                        valid = False
+                        # logging.warning('failed to parse {}'.format(name))
+                prev_prefix = prefix
+            if prev_prefix is not None:
+                example['key'] = prev_prefix
+                if 'txt' in example:
+                    yield example
+
+        except Exception as ex:
+            logging.warning(
+                'In tar_file_and_group_full_data: {} when processing {}'.format(
+                    ex, sample['src']))
+        finally:
+            if stream is not None:
+                stream.close()
+            if 'process' in sample:
+                sample['process'].communicate()
+            sample['stream'].close()
+
+
+def parse_raw(data):
+    """ Parse key/wav/txt from json line
+
+    Args:
+        data: Iterable[str], each str is a json line with key/wav/txt
+
+    Returns:
+        Iterable[{key, wav, txt, sample_rate}]
+    """
+    for sample in data:
+        assert 'src' in sample
+        json_line = sample['src']
+        obj = json.loads(json_line)
+        assert 'key' in obj
+        assert 'wav' in obj
+        assert 'txt' in obj
+        key = obj['key']
+        wav_file = obj['wav']
+        txt = obj['txt']
+        try:
+            if 'start' in obj:
+                assert 'end' in obj
+                sample_rate = torchaudio.info(wav_file).sample_rate
+                start_frame = int(obj['start'] * sample_rate)
+                end_frame = int(obj['end'] * sample_rate)
+                waveform, _ = torchaudio.load(filepath=wav_file,
+                                              num_frames=end_frame -
+                                              start_frame,
+                                              frame_offset=start_frame)
+            else:
+                waveform, sample_rate = torchaudio.load(wav_file)
+            # Check the number of audio channels
+            num_channels = waveform.shape[0]
+            # If the audio is multi-channel, average it down to mono
+            if num_channels > 1:
+                waveform = torch.mean(waveform, dim=0, keepdim=True)
+            example = copy.deepcopy(obj)  # copy and keep all the fields
+            example['wav'] = waveform  # overwrite wav
+            example['sample_rate'] = sample_rate
+            yield example
+        except Exception as ex:
+            logging.warning('Failed to read {}'.format(wav_file))
+
+
+def parse_speaker(data, speaker_table_path):
+    speaker_dict = {}
+    with open(speaker_table_path, 'r', encoding='utf8') as fin:
+        for line in fin:
+            arr = line.strip().split()
+            speaker_dict[arr[0]] = int(arr[1])
+    for sample in data:
+        assert 'speaker' in sample
+        speaker = sample['speaker']
+        sample['speaker'] = speaker_dict.get(speaker, 0)
+        yield sample
+
+
+def filter(data,
+           max_length=1200,
+           min_length=10,
+           token_max_length=250,
+           token_min_length=1,
+           min_output_input_ratio=0.00005,
+           max_output_input_ratio=1,
+           filter_no_extra_info: bool = False,
+           max_seq_len=1000):
+    """ Filter sample according to feature and label length
+    Inplace operation.
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+        max_length: drop utterance which is greater than max_length(10ms)
+        min_length: drop utterance which is less than min_length(10ms)
+        token_max_length: drop utterance which is greater than
+            token_max_length, especially when use char unit for
+            english modeling
+        token_min_length: drop utterance which is
+            less than token_max_length
+        min_output_input_ratio: minimal ratio of
+            token_length / feats_length(10ms)
+        max_output_input_ratio: maximum ratio of
+            token_length / feats_length(10ms)
+
+    Returns:
+        Iterable[{key, wav, label, sample_rate}]
+    """
+    for sample in data:
+        try:
+            assert 'sample_rate' in sample
+            assert 'wav' in sample
+            assert 'label' in sample
+        except Exception:
+            continue
+        # sample['wav'] is torch.Tensor, we have 100 frames every second
+        num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
+
+        # filter for shard_in_common
+        if filter_no_extra_info:
+            if 'lang' not in sample:
+                continue
+            if 'task' not in sample:
+                continue
+
+        if num_frames < min_length:
+            continue
+
+        # if "output_type" in sample and sample["output_type"] == "speech2text_token":
+        #     max_length = int(max_length / 2)
+        # if "output_type" in sample and sample["output_type"] == "text2token":
+        #     max_length = int(max_length / 1.5)
+        if num_frames > max_length:
+            # continue
+            if 'task' in sample and sample['task'] == '':
+                # utils_file.logging_limit_print('random crop applied')
+                # Randomly choose a start point and crop to max_length
+                start_frame = random.randint(0, int(num_frames - max_length))
+                end_frame = start_frame + max_length
+                sample['wav'] = sample['wav'][:, int(
+                    start_frame / 100 * sample['sample_rate']):int(
+                        end_frame / 100 * sample['sample_rate'])]
+                # print('sample[', sample['wav'].shape)
+            else:
+                continue
+        if len(sample['label']) < token_min_length:
+            continue
+        if len(sample['label']) > token_max_length:
+            continue
+        # if num_frames != 0:
+        #     if len(sample['label']) / num_frames < min_output_input_ratio:
+        #         continue
+        #     if len(sample['label']) / num_frames > max_output_input_ratio:
+        #         continue
+
+        if sample["output_type"] == "speech2text_token":
+            seq_len = len(sample['prompt']) + num_frames / 8 + len(sample['label']) + len(sample['speech_token'])
+        elif sample["output_type"] == "text2token":
+            seq_len = len(sample['prompt']) + len(sample['label']) + len(sample['speech_token'])
+        else:
+            seq_len = len(sample['prompt']) + num_frames / 8 + len(sample['label'])
+        utils_file.logging_limit_print(f'seqlen: {seq_len}, output_type:{sample["output_type"]},len(sample["prompt"]):{len(sample["prompt"])},num_frames / 8:{num_frames / 8},len(sample["label"]):{len(sample["label"])},len(sample["speech_token"]):{len(sample["speech_token"])} ')
+        if max_seq_len > 0 and max_seq_len < seq_len:
+            utils_file.logging_limit_print(
+                f'seqlen: {seq_len} exceeds max_seq_len: {max_seq_len}, skip this sample')
+            continue
+        yield sample
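+
+# NOTE: `filter` is normally driven from the yaml via `filter_conf`
+# (see Dataset in wenet/dataset/dataset.py). An illustrative, not
+# prescriptive, configuration mirroring the defaults above:
+#
+#     filter_conf:
+#         max_length: 1200
+#         min_length: 10
+#         token_max_length: 250
+#         token_min_length: 1
+#         filter_no_extra_info: true
+#         max_seq_len: 1000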
+
+
+def resample(data, resample_rate=16000):
+    """ Resample data.
+        Inplace operation.
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+        resample_rate: target resample rate
+
+    Returns:
+        Iterable[{key, wav, label, sample_rate}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav']
+        if sample_rate != resample_rate:
+            sample['sample_rate'] = resample_rate
+            sample['wav'] = torchaudio.transforms.Resample(
+                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
+        yield sample
+
+
+def speed_perturb(data, speeds=None):
+    """ Apply speed perturb to the data.
+        Inplace operation.
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+        speeds(List[float]): optional speed
+
+    Returns:
+        Iterable[{key, wav, label, sample_rate}]
+    """
+    if speeds is None:
+        speeds = [0.9, 1.0, 1.1]
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav']
+        speed = random.choice(speeds)
+        if speed != 1.0:
+            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
+                waveform, sample_rate,
+                [['speed', str(speed)], ['rate', str(sample_rate)]])
+            sample['wav'] = wav
+
+        yield sample
+
+
+def compute_fbank(data,
+                  num_mel_bins=23,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0):
+    """ Extract fbank
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+
+    Returns:
+        Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        assert 'key' in sample
+        assert 'label' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav']
+        waveform = waveform * (1 << 15)
+        # Only keep key, feat, label
+        mat = kaldi.fbank(waveform,
+                          num_mel_bins=num_mel_bins,
+                          frame_length=frame_length,
+                          frame_shift=frame_shift,
+                          dither=dither,
+                          energy_floor=0.0,
+                          sample_frequency=sample_rate)
+        sample['feat'] = mat
+        yield sample
+
+
+def compute_mfcc(data,
+                 num_mel_bins=23,
+                 frame_length=25,
+                 frame_shift=10,
+                 dither=0.0,
+                 num_ceps=40,
+                 high_freq=0.0,
+                 low_freq=20.0):
+    """ Extract mfcc
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+
+    Returns:
+        Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        assert 'key' in sample
+        assert 'label' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav']
+        waveform = waveform * (1 << 15)
+        # Only keep key, feat, label
+        mat = kaldi.mfcc(waveform,
+                         num_mel_bins=num_mel_bins,
+                         frame_length=frame_length,
+                         frame_shift=frame_shift,
+                         dither=dither,
+                         num_ceps=num_ceps,
+                         high_freq=high_freq,
+                         low_freq=low_freq,
+                         sample_frequency=sample_rate)
+        sample['feat'] = mat
+        yield sample
+
+
+def compute_log_mel_spectrogram(data,
+                                n_fft=400,
+                                hop_length=160,
+                                num_mel_bins=80,
+                                padding=0):
+    """ Extract log mel spectrogram, modified from openai-whisper, see:
+        - https://github.com/openai/whisper/blob/main/whisper/audio.py
+        - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040
+
+    Args:
+        data: Iterable[{key, wav, label, sample_rate}]
+
+    Returns:
+        Iterable[{key, feat, label}]
+    """
+    for sample in data:
+        assert 'sample_rate' in sample
+        assert 'wav' in sample
+        assert 'key' in sample
+        assert 'label' in sample
+        sample_rate = sample['sample_rate']
+        waveform = sample['wav'].squeeze(0)  # (channel=1, sample) -> (sample,)
+        # print(f'waveform shape: {waveform.shape}')
+        if padding > 0:
+            waveform = F.pad(waveform, (0, padding))
+        window = torch.hann_window(n_fft)
+        stft = torch.stft(waveform,
+                          n_fft,
+                          hop_length,
+                          window=window,
+                          return_complex=True)
+        magnitudes = stft[..., :-1].abs() ** 2
+
+        filters = torch.from_numpy(
+            librosa.filters.mel(sr=sample_rate,
+                                n_fft=n_fft,
+                                n_mels=num_mel_bins))
+        mel_spec = filters @ magnitudes
+
+        # NOTE(xcsong): https://github.com/openai/whisper/discussions/269
+        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = (log_spec + 4.0) / 4.0
+        sample['feat'] = log_spec.transpose(0, 1)
+        yield sample
+
+
+import re
+
+
+def process_text(text):
+    # 1. Remove spaces around Chinese characters
+    text = re.sub(r'\s*([\u4e00-\u9fff])\s*', r'\1', text)
+    # 2. Lower-case the English text
+    text = text.lower()
+    # 3. Remove spaces around the < and > symbols
+    text = re.sub(r'\s*<\s*', '<', text)
+    text = re.sub(r'\s*>\s*', '>', text)
+    return text
+
+
+global_style_dict = {
+    "朗读": "新闻科普",
+    "科普百科": "新闻科普",
+    "悬疑恐怖": "恐怖故事",
+    "童话故事": "童话故事",
+    "客服": "客服",
+    "诗歌": "诗歌散文",
+    "散文": "诗歌散文",
+    "武侠评书": "有声书",
+    "小说": "有声书",
+    "历史": "有声书",
+    "科幻": "有声书",
+    "对话": "日常口语",
+    "口语": "日常口语",
+    "幽默": "其他",
+    "其他": "其他",
+}
+
+
+def replace_keys_in_brackets(input_str, key_value_dict):
+    for key, value in key_value_dict.items():
+        # Build a regex pattern that matches the key wrapped in angle brackets
+        pattern = re.compile(r'<{}>'.format(key))
+        input_str = pattern.sub(f"<{value}>", input_str)
+    return input_str
+
+
+def tokenize(data, tokenizer: BaseTokenizer, global_prompt_dict=None):
+    """ Decode text to chars or BPE
+        Inplace operation
+
+    Args:
+        data: Iterable[{key, wav, txt, sample_rate}]
+
+    Returns:
+        Iterable[{key, wav, txt, tokens, label, sample_rate}]
+    """
+    for sample in data:
+        try:
+            assert 'txt' in sample
+        except Exception:
+            print(f'tokenize: {sample}')
+            exit()
+        if 'task' in sample:
+            task_name = sample['task']
+            # if "" in task_name:
+            #     txt = sample['txt'].replace("", "").replace("", "").replace("", "")
+            if "