From e5641ca43c2565b5f2e1e8fa714395774295264d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 1 Apr 2021 08:27:23 +0000 Subject: [PATCH] fix bugs, refactor collator, add pad_sequence, fix ckpt bugs --- deepspeech/__init__.py | 93 +++++ deepspeech/io/collator.py | 69 ++-- deepspeech/io/utility.py | 82 ++++ deepspeech/models/deepspeech2.py | 14 +- deepspeech/models/u2.py | 638 +++++++++++++++++++++++++++++++ deepspeech/modules/conv.py | 5 +- deepspeech/modules/rnn.py | 2 +- deepspeech/training/trainer.py | 18 +- deepspeech/utils/checkpoint.py | 13 +- deepspeech/utils/utility.py | 1 + 10 files changed, 880 insertions(+), 55 deletions(-) create mode 100644 deepspeech/io/utility.py create mode 100644 deepspeech/models/u2.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index fa5ef0439..3889f3a73 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -13,6 +13,9 @@ # limitations under the License. import logging from typing import Union +from typing import Optional +from typing import List +from typing import Tuple from typing import Any import paddle @@ -83,6 +86,20 @@ if not hasattr(paddle.Tensor, 'numel'): paddle.Tensor.numel = paddle.numel +def new_full(x: paddle.Tensor, + size: Union[List[int], Tuple[int], paddle.Tensor], + fill_value: Union[float, int, bool, paddle.Tensor], + dtype=None): + return paddle.full(size, fill_value, dtype=x.dtype) + + +if not hasattr(paddle.Tensor, 'new_full'): + logger.warn( + "override new_full of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.new_full = new_full + + def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place)) @@ -279,6 +296,7 @@ if not hasattr(paddle.nn, 'Module'): logger.warn("register user Module to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'Module', paddle.nn.Layer) +# maybe cause assert isinstance(sublayer, core.Layer) if not hasattr(paddle.nn, 'ModuleList'): logger.warn( "register user ModuleList to paddle.nn, remove this when fixed!") @@ -332,3 +350,78 @@ if not hasattr(paddle.nn, 'ConstantPad2d'): logger.warn( "register user ConstantPad2d to paddle.nn, remove this when fixed!") setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d) + +########### hcak paddle.jit ############# + +if not hasattr(paddle.jit, 'export'): + logger.warn("register user export to paddle.jit, remove this when fixed!") + setattr(paddle.jit, 'export', paddle.jit.to_static) + + +########### hcak paddle.nn.utils ############# +def pad_sequence(sequences: List[paddle.Tensor], + batch_first: bool=False, + padding_value: float=0.0) -> paddle.Tensor: + r"""Pad a list of variable length Tensors with ``padding_value`` + + ``pad_sequence`` stacks a list of Tensors along a new dimension, + and pads them to equal length. For example, if the input is list of + sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` + otherwise. + + `B` is batch size. It is equal to the number of elements in ``sequences``. + `T` is length of the longest sequence. + `L` is length of the sequence. + `*` is any number of trailing dimensions, including none. + + Example: + >>> from paddle.nn.utils.rnn import pad_sequence + >>> a = paddle.ones(25, 300) + >>> b = paddle.ones(22, 300) + >>> c = paddle.ones(15, 300) + >>> pad_sequence([a, b, c]).size() + paddle.Tensor([25, 3, 300]) + + Note: + This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. 
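+        Shorter sequences are padded with ``padding_value`` up to length `T`.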
This function assumes + trailing dimensions and type of all the Tensors in sequences are same. + + Args: + sequences (list[Tensor]): list of variable length sequences. + batch_first (bool, optional): output will be in ``B x T x *`` if True, or in + ``T x B x *`` otherwise + padding_value (float, optional): value for padded elements. Default: 0. + + Returns: + Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. + Tensor of size ``B x T x *`` otherwise + """ + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + max_size = sequences[0].size() + trailing_dims = max_size[1:] + max_len = max([s.size(0) for s in sequences]) + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + + out_tensor = sequences[0].new_full(out_dims, padding_value) + for i, tensor in enumerate(sequences): + length = tensor.size(0) + # use index notation to prevent duplicate references to the tensor + if batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor + + +if not hasattr(paddle.nn.utils, 'rnn.pad_sequence'): + logger.warn( + "register user rnn.pad_sequence to paddle.nn.utils, remove this when fixed!" + ) + setattr(paddle.nn.utils, 'rnn.pad_sequence', pad_sequence) diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 10f838fb2..cfe409911 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -16,15 +16,15 @@ import logging import numpy as np from collections import namedtuple +from deepspeech.io.utility import pad_sequence + logger = logging.getLogger(__name__) -__all__ = [ - "SpeechCollator", -] +__all__ = ["SpeechCollator"] class SpeechCollator(): - def __init__(self, padding_to=-1, is_training=True): + def __init__(self, is_training=True): """ Padding audio features with zeros to make them have the same shape (or a user-defined shape) within one bach. @@ -32,42 +32,51 @@ class SpeechCollator(): If ``padding_to`` is -1, the maximun shape in the batch will be used as the target shape for padding. Otherwise, `padding_to` will be the target shape (only refers to the second axis). + + if ``is_training`` is True, text is token ids else is raw string. """ - self._padding_to = padding_to self._is_training = is_training def __call__(self, batch): - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, _ in batch]) - if self._padding_to != -1: - if self._padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be larger " - "than any instance's shape in the batch") - max_length = self._padding_to - max_text_length = max([len(text) for _, text in batch]) - # padding - padded_audios = [] + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + text (List[int] or str): shape (U,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. 
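+                audio is padded with 0.0 and text with -1, each up to the
+                longest example in the batch: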
+ audio : (B, Tmax, D) + text : (B, Umax) + audio_lens: (B) + text_lens: (B) + """ + audios = [] audio_lens = [] - texts, text_lens = [], [] + texts = [] + text_lens = [] for audio, text in batch: # audio - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - padded_audios.append(padded_audio) + audios.append(audio.T) # [T, D] audio_lens.append(audio.shape[1]) # text - padded_text = np.zeros([max_text_length]) + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [] if self._is_training: - padded_text[:len(text)] = text # token ids + tokens = text # token ids else: - padded_text[:len(text)] = [ord(t) - for t in text] # string, unicode ord - texts.append(padded_text) + assert isinstance(text, str) + tokens = [ord(t) for t in text] + tokens = tokens if isinstance(tokens, np.ndarray) else np.array( + tokens, dtype=np.int64) + texts.append(tokens) text_lens.append(len(text)) - padded_audios = np.array(padded_audios).astype('float32') - audio_lens = np.array(audio_lens).astype('int64') - texts = np.array(texts).astype('int32') - text_lens = np.array(text_lens).astype('int64') - return padded_audios, texts, audio_lens, text_lens + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + padded_texts = pad_sequence(texts, padding_value=-1).astype(np.int32) + audio_lens = np.array(audio_lens).astype(np.int64) + text_lens = np.array(text_lens).astype(np.int64) + return padded_audios, padded_texts, audio_lens, text_lens diff --git a/deepspeech/io/utility.py b/deepspeech/io/utility.py new file mode 100644 index 000000000..46c9fbd29 --- /dev/null +++ b/deepspeech/io/utility.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import numpy as np +from collections import namedtuple +from typing import List + +logger = logging.getLogger(__name__) + +__all__ = ["pad_sequence"] + + +def pad_sequence(sequences: List[np.ndarray], + batch_first: bool=True, + padding_value: float=0.0) -> np.ndarray: + r"""Pad a list of variable length Tensors with ``padding_value`` + + ``pad_sequence`` stacks a list of Tensors along a new dimension, + and pads them to equal length. For example, if the input is list of + sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` + otherwise. + + `B` is batch size. It is equal to the number of elements in ``sequences``. + `T` is length of the longest sequence. + `L` is length of the sequence. + `*` is any number of trailing dimensions, including none. + + Example: + >>> a = np.ones([25, 300]) + >>> b = np.ones([22, 300]) + >>> c = np.ones([15, 300]) + >>> pad_sequence([a, b, c]).shape + [25, 3, 300] + + Note: + This function returns a np.ndarray of size ``T x B x *`` or ``B x T x *`` + where `T` is the length of the longest sequence. This function assumes + trailing dimensions and type of all the Tensors in sequences are same. 
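+        The output array inherits the dtype of ``sequences[0]``.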
+ + Args: + sequences (list[np.ndarray]): list of variable length sequences. + batch_first (bool, optional): output will be in ``B x T x *`` if True, or in + ``T x B x *`` otherwise + padding_value (float, optional): value for padded elements. Default: 0. + + Returns: + np.ndarray of size ``T x B x *`` if :attr:`batch_first` is ``False``. + np.ndarray of size ``B x T x *`` otherwise + """ + + # assuming trailing dimensions and type of all the Tensors + # in sequences are same and fetching those from sequences[0] + max_size = sequences[0].shape + trailing_dims = max_size[1:] + max_len = max([s.shape[0] for s in sequences]) + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + + out_tensor = np.full(out_dims, padding_value, dtype=sequences[0].dtype) + for i, tensor in enumerate(sequences): + length = tensor.shape[0] + # use index notation to prevent duplicate references to the tensor + if batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index cab1e45e1..01edbbae6 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Deepspeech2 ASR Model""" import math import collections import numpy as np @@ -67,23 +67,19 @@ class CRNNEncoder(nn.Layer): return self.rnn_size * 2 def forward(self, audio, audio_len): - """ - audio: shape [B, D, T] - text: shape [B, T] - audio_len: shape [B] - text_len: shape [B] - """ """Compute Encoder outputs Args: - audio (Tensor): [B, D, T] - text (Tensor): [B, T] + audio (Tensor): [B, Tmax, D] + text (Tensor): [B, Umax] audio_len (Tensor): [B] text_len (Tensor): [B] Returns: x (Tensor): encoder outputs, [B, T, D] x_lens (Tensor): encoder length, [B] """ + # [B, T, D] -> [B, D, T] + audio = audio.transpose([0, 2, 1]) # [B, D, T] -> [B, C=1, D, T] x = audio.unsqueeze(1) x_lens = audio_len diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py new file mode 100644 index 000000000..9570aad1e --- /dev/null +++ b/deepspeech/models/u2.py @@ -0,0 +1,638 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
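+# Decoding entry points implemented below: attention beam search (recognize),
+# CTC greedy search, CTC prefix beam search, and attention rescoring of the
+# CTC prefix beam search n-best list.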
+"""U2 ASR Model +Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition +(https://arxiv.org/pdf/2012.05481.pdf) +""" +import math +import collections +from collections import defaultdict +import numpy as np +import logging +from yacs.config import CfgNode +from typing import List, Optional, Tuple + +import paddle +from paddle import jit +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I +from paddle.nn.utils.rnn import pad_sequence + +from deepspeech.modules.cmvn import GlobalCMVN +from deepspeech.modules.encoder import ConformerEncoder +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.modules.decoder import TransformerDecoder +from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss +from deepspeech.modules.mask import make_pad_mask +from deepspeech.modules.mask import mask_finished_preds +from deepspeech.modules.mask import mask_finished_scores +from deepspeech.modules.mask import subsequent_mask + +from deepspeech.utils import checkpoint +from deepspeech.utils import layer_tools +from deepspeech.utils.cmvn import load_cmvn +from deepspeech.utils.utility import log_add +from deepspeech.utils.tensor_utils import IGNORE_ID +from deepspeech.utils.tensor_utils import add_sos_eos +from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.ctc_utils import remove_duplicates_and_blank + +logger = logging.getLogger(__name__) + +__all__ = ['U2Model'] + + +class U2Model(nn.Module): + """CTC-Attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + encoder: TransformerEncoder, + decoder: TransformerDecoder, + ctc: CTCDecoder, + ctc_weight: float=0.5, + ignore_id: int=IGNORE_ID, + lsm_weight: float=0.0, + length_normalized_loss: bool=False, ): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + + self.encoder = encoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, ) + + def forward( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[ + paddle.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + + # 2a. Attention-decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + loss_att = None + + # 2b. 
CTC branch + if self.ctc_weight != 0.0: + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + else: + loss_ctc = None + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + return loss, loss_att, loss_ctc + + def _calc_att_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def recognize( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int=10, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> paddle.Tensor: + """ Apply beam search on attention decoder + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + paddle.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.device + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. 
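+        #    (batch and beam are flattened into one leading dimension of
+        #     size B*N = batch_size * beam_size for the decoder steps below)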
Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_dim = encoder_out.size(2) + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = torch.ones( + [running_size, 1], dtype=torch.long, + device=device).fill_(self.sos) # (B*N, 1) + scores = paddle.tensor( + [0.0] + [-float('inf')] * (beam_size - 1), dtype=torch.float) + scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device) + cache: Optional[List[paddle.Tensor]] = None + # 2. Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + if end_flag.sum() == running_size: + break + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + # 2.3 Seconde beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = torch.arange( + batch_size, + device=device).view(-1, 1).repeat([1, beam_size]) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = torch.index_select( + top_k_index.view(-1), dim=-1, index=best_k_index) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = torch.index_select( + hyps, dim=0, index=best_hyps_index) # (B*N, i) + hyps = torch.cat( + (last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. 
Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_index = torch.argmax(scores, dim=-1).long() + best_hyps_index = best_index + torch.arange( + batch_size, dtype=torch.long, device=device) * beam_size + best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index) + best_hyps = best_hyps[:, 1:] + return best_hyps + + def ctc_greedy_search( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> List[List[int]]: + """ Apply CTC greedy search + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[List[int]]: best path result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # Let's assume B = batch_size + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + mask = make_pad_mask(encoder_out_lens) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps + + def _ctc_prefix_beam_search( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + ) -> Tuple[List[List[int]], paddle.Tensor]: + """ CTC prefix beam search inner implementation + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[List[int]]: nbest results + paddle.Tensor: encoder output, (1, max_len, encoder_dim), + it will be used for rescoring in attention rescoring mode + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # For CTC prefix beam search, we only support batch_size=1 + assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size + # 1. 
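+        #    (each prefix keeps two log-domain scores: pb for hypotheses
+        #     ending in blank and pnb for those ending in a non-blank token)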
Encoder forward and get CTC score + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + # 2.1 First beam prune: select topk best + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] + return hyps, encoder_out + + def ctc_prefix_beam_search( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> List[int]: + """ Apply CTC prefix beam search + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[int]: CTC prefix beam search nbest results + """ + hyps, _ = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + return hyps[0][0] + + def attention_rescoring( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + ctc_weight: float=0.0, + simulate_streaming: bool=False, ) -> List[int]: + """ Apply attention rescoring decoding, CTC prefix beam search + is applied first to get nbest, then we resoring the nbest on + attention decoder with corresponding encoder out + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[int]: Attention rescoring result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.device + batch_size = speech.shape[0] + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size + hyps, encoder_out = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + + assert len(hyps) == beam_size + hyps_pad = pad_sequence([ + paddle.tensor(hyp[0], device=device, dtype=torch.long) + for hyp in hyps + ], True, self.ignore_id) # (beam_size, max_hyps_len) + hyps_lens = paddle.tensor( + [len(hyp[0]) for hyp in hyps], device=device, + dtype=torch.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = torch.ones( + beam_size, 1, encoder_out.size(1), dtype=torch.bool, device=device) + decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = decoder_out.cpu().numpy() + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + score += decoder_out[i][len(hyp[0])][self.eos] + # add ctc score + score += hyp[1] * ctc_weight + if score > best_score: + best_score = score + best_index = i + return hyps[best_index][0] + + @jit.export + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + @jit.export + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + @jit.export + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + @jit.export + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @jit.export + def forward_encoder_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[paddle.Tensor]=None, + elayers_output_cache: Optional[List[paddle.Tensor]]=None, + conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ + paddle.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + Args: + xs (paddle.Tensor): chunk input + subsampling_cache (Optional[paddle.Tensor]): subsampling cache + elayers_output_cache (Optional[List[paddle.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer + cnn cache + Returns: + paddle.Tensor: output, it ranges from time 0 to current chunk. 
+ paddle.Tensor: subsampling cache + List[paddle.Tensor]: attention cache + List[paddle.Tensor]: conformer cnn cache + """ + return self.encoder.forward_chunk( + xs, offset, required_cache_size, subsampling_cache, + elayers_output_cache, conformer_cnn_cache) + + @jit.export + def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: + """ Export interface for c++ call, apply linear transform and log + softmax before ctc + Args: + xs (paddle.Tensor): encoder output + Returns: + paddle.Tensor: activation before ctc + """ + return self.ctc.log_softmax(xs) + + @jit.export + def forward_attention_decoder( + self, + hyps: paddle.Tensor, + hyps_lens: paddle.Tensor, + encoder_out: paddle.Tensor, ) -> paddle.Tensor: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (paddle.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining + hyps_lens (paddle.Tensor): length of each hyp in hyps + encoder_out (paddle.Tensor): corresponding encoder output + Returns: + paddle.Tensor: decoder output + """ + assert encoder_out.size(0) == 1 + num_hyps = hyps.size(0) + assert hyps_lens.size(0) == num_hyps + encoder_out = encoder_out.repeat(num_hyps, 1, 1) + encoder_mask = torch.ones( + num_hyps, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=encoder_out.device) + decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps, + hyps_lens) # (num_hyps, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + return decoder_out + + +def init_asr_model(configs): + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), torch.from_numpy(istd).float()) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + + encoder_type = configs.get('encoder', 'conformer') + if encoder_type == 'conformer': + encoder = ConformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + else: + encoder = TransformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + + decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + ctc = CTCDecoder(vocab_size, encoder.output_size()) + model = U2Model( + vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf'], ) + return model diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py index 0c08624b5..d17f30522 100644 --- a/deepspeech/modules/conv.py +++ b/deepspeech/modules/conv.py @@ -145,7 +145,7 @@ class ConvStack(nn.Layer): act='brelu') out_channel = 32 - self.conv_stack = nn.Sequential([ + convs = [ ConvBn( num_channels_in=32, num_channels_out=out_channel, @@ -153,7 +153,8 @@ class ConvStack(nn.Layer): stride=(2, 1), padding=(10, 5), act='brelu') for i in range(num_stacks - 1) - ]) + ] + self.conv_stack = nn.LayerList(convs) # conv output feat_dim output_height = (feat_size - 1) // 2 + 1 diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py index a6da4a11c..22dd13a44 100644 --- a/deepspeech/modules/rnn.py +++ b/deepspeech/modules/rnn.py @@ -298,7 +298,7 @@ class RNNStack(nn.Layer): share_weights=share_rnn_weights)) i_size = h_size * 2 - self.rnn_stacks = nn.Sequential(rnn_stacks) + self.rnn_stacks = nn.ModuleList(rnn_stacks) def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): """ diff --git 
a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 3a381b2b7..39bb1ccd0 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -128,14 +128,15 @@ class Trainer(): dist.init_parallel_env() @mp_tools.rank_zero_only - def save(self): + def save(self, infos=None): """Save checkpoint (model parameters and optimizer states). """ - infos = { - "step": self.iteration, - "epoch": self.epoch, - "lr": self.optimizer.get_lr(), - } + if infos is None: + infos = { + "step": self.iteration, + "epoch": self.epoch, + "lr": self.optimizer.get_lr(), + } checkpoint.save_parameters(self.checkpoint_dir, self.iteration, self.model, self.optimizer, infos) @@ -151,8 +152,9 @@ class Trainer(): self.optimizer, checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) - self.iteration = infos["step"] - self.epoch = infos["epoch"] + if infos: + self.iteration = infos["step"] + self.epoch = infos["epoch"] def new_epoch(self): """Reset the train loader and increment ``epoch``. diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index d265358b2..c265e5929 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -36,11 +36,11 @@ def _load_latest_checkpoint(checkpoint_dir: str) -> int: Args: checkpoint_dir (str): the directory where checkpoint is saved. Returns: - int: the latest iteration number. + int: the latest iteration number. -1 for no checkpoint to load. """ checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") if not os.path.isfile(checkpoint_record): - return 0 + return -1 # Fetch the latest checkpoint index. with open(checkpoint_record, "rt") as handle: @@ -79,11 +79,15 @@ def load_parameters(model, Returns: configs (dict): epoch or step, lr and other meta info should be saved. """ + configs = {} + if checkpoint_path is not None: iteration = int(os.path.basename(checkpoint_path).split(":")[-1]) elif checkpoint_dir is not None: iteration = _load_latest_checkpoint(checkpoint_dir) - checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration)) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) else: raise ValueError( "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" @@ -104,7 +108,6 @@ def load_parameters(model, rank, optimizer_path)) info_path = re.sub('.pdparams$', '.json', params_path) - configs = {} if os.path.exists(info_path): with open(info_path, 'r') as fin: configs = json.load(fin) @@ -128,7 +131,7 @@ def save_parameters(checkpoint_dir: str, Returns: None """ - checkpoint_path = os.path.join(checkpoint_dir, "-{}".format(iteration)) + checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) model_dict = model.state_dict() params_path = checkpoint_path + ".pdparams" diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 8b45c75da..96f253b53 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -16,6 +16,7 @@ import math import numpy as np import distutils.util +from typing import List __all__ = ['print_arguments', 'add_arguments', "log_add"]
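
Note: the new u2.py imports `log_add` from deepspeech.utils.utility, but this
patch only adds the `typing.List` import to that module and does not show the
function body. As a reference only, a numerically stable log-sum-exp sketch
consistent with how the CTC prefix beam search calls it (a list of
log-probabilities in, a single log-probability out) could look like the
snippet below; treat it as an assumption about, not a copy of, the in-tree
implementation:

    import math
    from typing import List

    def log_add(args: List[float]) -> float:
        # Stable computation of log(sum(exp(a) for a in args)).
        if all(a == -float('inf') for a in args):
            return -float('inf')
        a_max = max(args)
        return a_max + math.log(sum(math.exp(a - a_max) for a in args))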