From c8e96d732ba4a941da9c72cf23d01d058564615c Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 2 Oct 2021 09:44:14 +0000 Subject: [PATCH] bool logical, sum and multiply op; ctc grad norm; support old and new pd api --- deepspeech/models/ds2/conv.py | 14 +++++++-- deepspeech/models/ds2/rnn.py | 6 ++-- deepspeech/models/u2/u2.py | 12 ++++++-- deepspeech/models/u2_st.py | 8 +++-- deepspeech/modules/decoder.py | 8 +++-- deepspeech/modules/encoder.py | 3 +- deepspeech/modules/loss.py | 53 ++++++++++++++++++++------------ deepspeech/modules/mask.py | 16 +++++++--- deepspeech/utils/tensor_utils.py | 10 ++++-- tests/mask_test.py | 4 +-- 10 files changed, 93 insertions(+), 41 deletions(-) diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py index 9548af0a2..069b7dd4b 100644 --- a/deepspeech/models/ds2/conv.py +++ b/deepspeech/models/ds2/conv.py @@ -41,6 +41,13 @@ def conv_output_size(I, F, P, S): return (I - F + 2 * P - S) // S +# receptive field calculator +# https://fomoro.com/research/article/receptive-field-calculator +# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters +# https://distill.pub/2019/computing-receptive-fields/ +# Rl-1 = Sl * Rl + (Kl - Sl) + + class ConvBn(nn.Layer): """Convolution layer with batch normalization. @@ -106,9 +113,10 @@ class ConvBn(nn.Layer): # reset padding part to 0 masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # https://github.com/PaddlePaddle/Paddle/pull/29265 - # rhs will type promote to lhs - x = x * masks + # TODO(Hui Zhang): not support bool multiply + # masks = masks.type_as(x) + masks = masks.astype(x.dtype) + x = x.multiply(masks) return x, x_len diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py index 3fc52a378..68a3e6e72 100644 --- a/deepspeech/models/ds2/rnn.py +++ b/deepspeech/models/ds2/rnn.py @@ -308,8 +308,8 @@ class RNNStack(nn.Layer): x, x_len = rnn(x, x_len) masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] - # https://github.com/PaddlePaddle/Paddle/pull/29265 - # rhs will type promote to lhs - x = x * masks + # TODO(Hui Zhang): not support bool multiply + masks = masks.astype(x.dtype) + x = x.multiply(masks) return x, x_len diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py index 46bbd102f..e6cd7b5c8 100644 --- a/deepspeech/models/u2/u2.py +++ b/deepspeech/models/u2/u2.py @@ -164,7 +164,10 @@ class U2BaseModel(nn.Layer): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + #TODO(Hui Zhang): sum not support bool type + #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( + 1) #[B, 1, T] -> [B] # 2a. Attention-decoder branch loss_att = None @@ -319,7 +322,8 @@ class U2BaseModel(nn.Layer): # 2. Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - if end_flag.sum() == running_size: + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: break # 2.1 Forward decoder step @@ -405,7 +409,9 @@ class U2BaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) maxlen = encoder_out.shape[1] - encoder_out_lens = encoder_mask.squeeze(1).sum(1) + # (TODO Hui Zhang): bool no support reduce_sum + # encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py index 8f87f6daa..bf98423d4 100644 --- a/deepspeech/models/u2_st.py +++ b/deepspeech/models/u2_st.py @@ -165,7 +165,10 @@ class U2STBaseModel(nn.Layer): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + #TODO(Hui Zhang): sum not support bool type + #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( + 1) #[B, 1, T] -> [B] # 2a. ST-decoder branch start = time.time() @@ -362,7 +365,8 @@ class U2STBaseModel(nn.Layer): # 2. Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - if end_flag.sum() == running_size: + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: break # 2.1 Forward decoder step diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index 8ca72894a..1ae3ce371 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -124,7 +124,9 @@ class TransformerDecoder(nn.Layer): # m: (1, L, L) m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) # tgt_mask: (B, L, L) - tgt_mask = tgt_mask & m + # TODO(Hui Zhang): not support & for tensor + # tgt_mask = tgt_mask & m + tgt_mask = tgt_mask.logical_and(m) x, _ = self.embed(tgt) for layer in self.decoders: @@ -135,7 +137,9 @@ class TransformerDecoder(nn.Layer): if self.use_output_layer: x = self.output_layer(x) - olens = tgt_mask.sum(1) + # TODO(Hui Zhang): reduce_sum not support bool type + # olens = tgt_mask.sum(1) + olens = tgt_mask.astype(paddle.int).sum(1) return x, olens def forward_one_step( diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index d4a8275c3..6ffb6465c 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -162,7 +162,8 @@ class BaseEncoder(nn.Layer): xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor masks = masks.astype(paddle.bool) - mask_pad = ~masks + #TODO(Hui Zhang): mask_pad = ~masks + mask_pad = masks.logical_not() chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index 1f33e5125..df5298ea3 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from functools import partial + import paddle from paddle import nn from paddle.nn import functional as F @@ -32,18 +35,19 @@ class CTCLoss(nn.Layer): # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.batch_average = batch_average + logger.info( f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") + logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") - # instance for norm_by_times - # batch for norm_by_batchsize - # frame for norm_by_total_logits_len assert grad_norm_type in ('instance', 'batch', 'frame', None) self.norm_by_times = False self.norm_by_batchsize = False self.norm_by_total_logits_len = False - logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") - if grad_norm_type == 'instance': + if grad_norm_type is None: + # no grad norm + pass + elif grad_norm_type == 'instance': self.norm_by_times = True elif grad_norm_type == 'batch': self.norm_by_batchsize = True @@ -51,6 +55,22 @@ class CTCLoss(nn.Layer): self.norm_by_total_logits_len = True else: raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}") + self.kwargs = { + "norm_by_times": self.norm_by_times, + "norm_by_batchsize": self.norm_by_batchsize, + "norm_by_total_logits_len": self.norm_by_total_logits_len, + } + + # Derive only the args which the func has + try: + param = inspect.signature(self.loss.forward).parameters + except ValueError: + # Some function, e.g. built-in function, are failed + param = {} + _kwargs = {k: v for k, v in self.kwargs.items() if k in param} + _notin = {k: v for k, v in self.kwargs.items() if k not in param} + logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}") + self.loss_fn = partial(self.loss.forward, **_kwargs) def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. @@ -70,14 +90,7 @@ class CTCLoss(nn.Layer): # logits: (B, L, D) -> (L, B, D) logits = logits.transpose([1, 0, 2]) ys_pad = ys_pad.astype(paddle.int32) - loss = self.loss( - logits, - ys_pad, - hlens, - ys_lens, - norm_by_times=self.norm_by_times, - norm_by_batchsize=self.norm_by_batchsize, - norm_by_total_logits_len=self.norm_by_total_logits_len) + loss = self.loss_fn(logits, ys_pad, hlens, ys_lens) if self.batch_average: # Batch-size average loss = loss / B @@ -118,8 +131,8 @@ class LabelSmoothingLoss(nn.Layer): size (int): the number of class padding_idx (int): padding class id which will be ignored for loss smoothing (float): smoothing rate (0.0 means the conventional CE) - normalize_length (bool): - True, normalize loss by sequence length; + normalize_length (bool): + True, normalize loss by sequence length; False, normalize loss by batch size. Defaults to False. """ @@ -136,7 +149,7 @@ class LabelSmoothingLoss(nn.Layer): The model outputs and data labels tensors are flatten to (batch*seqlen, class) shape and a mask is applied to the padding part which should not be calculated for loss. - + Args: x (paddle.Tensor): prediction (batch, seqlen, class) target (paddle.Tensor): @@ -152,7 +165,7 @@ class LabelSmoothingLoss(nn.Layer): # use zeros_like instead of torch.no_grad() for true_dist, # since no_grad() can not be exported by JIT true_dist = paddle.full_like(x, self.smoothing / (self.size - 1)) - ignore = (target == self.padding_idx) # (B,) + ignore = target == self.padding_idx # (B,) #TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index target = target.masked_fill(ignore, 0) # avoid -1 index @@ -163,8 +176,10 @@ class LabelSmoothingLoss(nn.Layer): kl = self.criterion(F.log_softmax(x, axis=1), true_dist) - total = len(target) - int(ignore.sum()) + #TODO(Hui Zhang): sum not support bool type + #total = len(target) - int(ignore.sum()) + total = len(target) - int(ignore.type_as(target).sum()) denom = total if self.normalize_length else B - #TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum() + #numer = (kl * (1 - ignore)).sum() numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum() return numer / denom diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index 6d46f5ba0..00f228a2b 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -69,7 +69,8 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]] """ - return ~make_pad_mask(lengths) + #return ~make_pad_mask(lengths) + return make_pad_mask(lengths).logical_not() def subsequent_mask(size: int) -> paddle.Tensor: @@ -91,7 +92,12 @@ def subsequent_mask(size: int) -> paddle.Tensor: [1, 1, 1]] """ ret = paddle.ones([size, size], dtype=paddle.bool) - return paddle.tril(ret) + #TODO(Hui Zhang): tril not support bool + #return paddle.tril(ret) + ret = ret.astype(paddle.float) + ret = paddle.tril(ret) + ret = ret.astype(paddle.bool) + return ret def subsequent_chunk_mask( @@ -180,13 +186,15 @@ def add_optional_chunk_mask(xs: paddle.Tensor, chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, num_left_chunks) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - chunk_masks = masks & chunk_masks # (B, L, L) + # chunk_masks = masks & chunk_masks # (B, L, L) + chunk_masks = masks.logical_and(chunk_masks) # (B, L, L) elif static_chunk_size > 0: num_left_chunks = num_decoding_left_chunks chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, num_left_chunks) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - chunk_masks = masks & chunk_masks # (B, L, L) + # chunk_masks = masks & chunk_masks # (B, L, L) + chunk_masks = masks.logical_and(chunk_masks) # (B, L, L) else: chunk_masks = masks return chunk_masks diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 0050794c7..0cc03b193 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -183,7 +183,13 @@ def th_accuracy(pad_outputs: paddle.Tensor, pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], pad_outputs.shape[1]).argmax(2) mask = pad_targets != ignore_label - numerator = paddle.sum( + #TODO(Hui Zhang): sum not support bool type + # numerator = paddle.sum( + # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + numerator = ( pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - denominator = paddle.sum(mask) + numerator = paddle.sum(numerator.type_as(pad_targets)) + #TODO(Hui Zhang): sum not support bool type + # denominator = paddle.sum(mask) + denominator = paddle.sum(mask.type_as(pad_targets)) return float(numerator) / float(denominator) diff --git a/tests/mask_test.py b/tests/mask_test.py index dbe8c4b09..f44aca8fc 100644 --- a/tests/mask_test.py +++ b/tests/mask_test.py @@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase): def test_make_non_pad_mask(self): res = make_non_pad_mask(self.lengths) - res2 = ~make_pad_mask(self.lengths) + res2 = make_pad_mask(self.lengths).logical_not() self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist()) def test_make_pad_mask(self): res = make_pad_mask(self.lengths) - res1 = ~make_non_pad_mask(self.lengths) + res1 = make_non_pad_mask(self.lengths).logical_not() self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist()) self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())