diff --git a/.notebook/dataloader.ipynb b/.notebook/dataloader.ipynb
index e2b8b3a0a..3de8f64a9 100644
--- a/.notebook/dataloader.ipynb
+++ b/.notebook/dataloader.ipynb
@@ -338,7 +338,7 @@
     }
    ],
    "source": [
-    "for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
+    "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
     "    print('test', text)\n",
     "    print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
     "    print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
@@ -386,4 +386,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/.notebook/train_test.ipynb b/.notebook/train_test.ipynb
index b2e454395..67212e50a 100644
--- a/.notebook/train_test.ipynb
+++ b/.notebook/train_test.ipynb
@@ -249,7 +249,7 @@
     }
    ],
    "source": [
-    "    for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
+    "    for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
     "        print('test', text)\n",
     "        print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
     "        print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
@@ -835,7 +835,7 @@
    "\n",
    "        return logits, probs, audio_len\n",
    "\n",
-   "    def forward(self, audio, text, audio_len, text_len):\n",
+   "    def forward(self, audio, audio_len, text, text_len):\n",
    "        \"\"\"\n",
    "        audio: shape [B, D, T]\n",
    "        text: shape [B, T]\n",
@@ -877,10 +877,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "audio, text, audio_len, text_len = None, None, None, None\n",
+   "audio, audio_len, text, text_len = None, None, None, None\n",
    "\n",
    "for idx, inputs in enumerate(batch_reader):\n",
-   "    audio, text, audio_len, text_len = inputs\n",
+   "    audio, audio_len, text, text_len = inputs\n",
    "# print(idx)\n",
    "# print('a', audio.shape, audio.place)\n",
    "# print('t', text)\n",
@@ -960,7 +960,7 @@
     }
    ],
    "source": [
-   "outputs = dp_model(audio, text, audio_len, text_len)\n",
+   "outputs = dp_model(audio, audio_len, text, text_len)\n",
    "logits, _, logits_len = outputs\n",
    "print('logits len', logits_len)\n",
    "loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index 3889f3a73..5a5a06b40 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -222,6 +222,31 @@
 if not hasattr(paddle.Tensor, 'relu'):
     logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
+
+def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
+    return x.astype(other.dtype)
+
+
+if not hasattr(paddle.Tensor, 'type_as'):
+    logger.warn(
+        "register user type_as to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'type_as', type_as)
+
+
+def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
+    assert len(args) == 1
+    if isinstance(args[0], str):  # dtype
+        return x.astype(args[0])
+    elif isinstance(args[0], paddle.Tensor):  # Tensor
+        return x.astype(args[0].dtype)
+    else:  # Device
+        return x
+
+
+if not hasattr(paddle.Tensor, 'to'):
+    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'to', to)
+
 ########### hcak paddle.nn.functional #############
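The `type_as` and `to` shims above backfill PyTorch-style `paddle.Tensor` methods that the ported model code calls. A minimal usage sketch of the registered helpers (a hedged example, not part of the patch; the tensor values are illustrative):

    import paddle
    import deepspeech  # importing the package registers the shims on paddle.Tensor

    x = paddle.ones([2, 3], dtype='float32')
    y = paddle.zeros([2, 3], dtype='int64')

    print(x.type_as(y).dtype)             # cast x to y's dtype -> paddle.int64
    print(x.to('int64').dtype)            # a str argument is treated as a dtype
    print(x.to(y).dtype)                  # a Tensor argument: cast to its dtype
    print(x.to(paddle.CPUPlace()).dtype)  # anything else is treated as a device; no-op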
diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py
index 33b83283f..50de94c3b 100644
--- a/deepspeech/exps/deepspeech2/bin/tune.py
+++ b/deepspeech/exps/deepspeech2/bin/tune.py
@@ -103,7 +103,7 @@ def tune(config, args):
             trans.append(''.join([chr(i) for i in ids]))
         return trans

-    audio, text, audio_len, text_len = infer_data
+    audio, audio_len, text, text_len = infer_data
     target_transcripts = ordid2token(text, text_len)
     num_ins += audio.shape[0]
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index cfe409911..c3a11f0f0 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -17,6 +17,7 @@ import numpy as np
 from collections import namedtuple

 from deepspeech.io.utility import pad_sequence
+from deepspeech.utils.tensor_utils import IGNORE_ID

 logger = logging.getLogger(__name__)

@@ -29,10 +30,6 @@ class SpeechCollator():
             Padding audio features with zeros to make them have the same
             shape (or a user-defined shape) within one bach.

-            If ``padding_to`` is -1, the maximun shape in the batch will be used
-            as the target shape for padding. Otherwise, `padding_to` will be the
-            target shape (only refers to the second axis).
-
             if ``is_training`` is True, text is token ids else is raw string.
         """
         self._is_training = is_training
@@ -48,8 +45,8 @@ class SpeechCollator():
         Returns:
             tuple(audio, text, audio_lens, text_lens): batched data.
                 audio : (B, Tmax, D)
-                text : (B, Umax)
                 audio_lens: (B)
+                text : (B, Umax)
                 text_lens: (B)
         """
         audios = []
@@ -76,7 +73,9 @@ class SpeechCollator():
         padded_audios = pad_sequence(
             audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
-        padded_texts = pad_sequence(texts, padding_value=-1).astype(np.int32)
         audio_lens = np.array(audio_lens).astype(np.int64)
+        # (TODO:Hui Zhang) ctc loss does not support int64 labels
+        padded_texts = pad_sequence(
+            texts, padding_value=IGNORE_ID).astype(np.int32)
         text_lens = np.array(text_lens).astype(np.int64)
-        return padded_audios, padded_texts, audio_lens, text_lens
+        return padded_audios, audio_lens, padded_texts, text_lens
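With this change the collator yields `(audio, audio_len, text, text_len)` and pads label ids with `IGNORE_ID` before the int32 cast that the CTC loss requires. A self-contained numpy sketch of the same padding logic (hypothetical values; `IGNORE_ID` is assumed here to be -1, and `pad_batch` is a stand-in for the real `pad_sequence`):

    import numpy as np

    IGNORE_ID = -1  # assumed value; the real constant lives in deepspeech.utils.tensor_utils

    def pad_batch(texts, padding_value):
        """Right-pad a list of 1-D int arrays to the longest length."""
        umax = max(len(t) for t in texts)
        out = np.full((len(texts), umax), padding_value, dtype=np.int64)
        for i, t in enumerate(texts):
            out[i, :len(t)] = t
        return out

    texts = [np.array([5, 9, 2]), np.array([7])]
    padded_texts = pad_batch(texts, IGNORE_ID).astype(np.int32)  # int32: CTC loss limitation
    text_lens = np.array([len(t) for t in texts], dtype=np.int64)
    print(padded_texts)  # [[ 5  9  2]
                         #  [ 7 -1 -1]]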
diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py
index cdf32cf37..56bd3bcf5 100644
--- a/deepspeech/models/deepspeech2.py
+++ b/deepspeech/models/deepspeech2.py
@@ -168,13 +168,13 @@ class DeepSpeech2Model(nn.Layer):
             dropout_rate=0.0,
             reduction=True)

-    def forward(self, audio, text, audio_len, text_len):
+    def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss

         Args:
             audio (Tenosr): [B, T, D]
-            text (Tensor): [B, U]
             audio_len (Tensor): [B]
+            text (Tensor): [B, U]
             text_len (Tensor): [B]

         Returns:
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index 9570aad1e..de4b7a08f 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -28,7 +28,12 @@ from paddle import jit
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
+
 from paddle.nn.utils.rnn import pad_sequence
+from deepspeech.modules.mask import make_pad_mask
+from deepspeech.modules.mask import mask_finished_preds
+from deepspeech.modules.mask import mask_finished_scores
+from deepspeech.modules.mask import subsequent_mask

 from deepspeech.modules.cmvn import GlobalCMVN
 from deepspeech.modules.encoder import ConformerEncoder
@@ -36,10 +41,6 @@
 from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.modules.decoder import TransformerDecoder
 from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
-from deepspeech.modules.mask import make_pad_mask
-from deepspeech.modules.mask import mask_finished_preds
-from deepspeech.modules.mask import mask_finished_scores
-from deepspeech.modules.mask import subsequent_mask
 from deepspeech.utils import checkpoint
 from deepspeech.utils import layer_tools
@@ -101,6 +102,8 @@ class U2Model(nn.Module):
             speech_lengths: (Batch, )
             text: (Batch, Length)
             text_lengths: (Batch,)
+        Returns:
+            total_loss, attention_loss, ctc_loss
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -109,21 +112,19 @@ class U2Model(nn.Module):
             text.shape, text_lengths.shape)
         # 1. Encoder
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]

         # 2a. Attention-decoder branch
+        loss_att = None
         if self.ctc_weight != 1.0:
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                     text, text_lengths)
-        else:
-            loss_att = None

         # 2b. CTC branch
+        loss_ctc = None
         if self.ctc_weight != 0.0:
             loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                 text_lengths)
-        else:
-            loss_ctc = None

         if loss_ctc is None:
             loss = loss_att
@@ -139,6 +140,17 @@ class U2Model(nn.Module):
                        encoder_mask: paddle.Tensor,
                        ys_pad: paddle.Tensor,
                        ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
+        """Calc attention loss.
+
+        Args:
+            encoder_out (paddle.Tensor): [B, Tmax, D]
+            encoder_mask (paddle.Tensor): [B, 1, Tmax]
+            ys_pad (paddle.Tensor): [B, Umax]
+            ys_pad_lens (paddle.Tensor): [B]
+
+        Returns:
+            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
+        """
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                             self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
@@ -163,6 +175,20 @@ class U2Model(nn.Module):
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False,
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Encoder pass.
+
+        Args:
+            speech (paddle.Tensor): [B, Tmax, D]
+            speech_lengths (paddle.Tensor): [B]
+            decoding_chunk_size (int, optional): chunk size. Defaults to -1.
+            num_decoding_left_chunks (int, optional): number of left chunks. Defaults to -1.
+            simulate_streaming (bool, optional): streaming or not. Defaults to False.
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor]:
+                encoder hiddens (B, Tmax, D),
+                encoder hiddens mask (B, 1, Tmax).
+        """
         # Let's assume B = batch_size
         # 1. Encoder
         if simulate_streaming and decoding_chunk_size > 0:
@@ -205,7 +231,7 @@ class U2Model(nn.Module):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
-        device = speech.device
+        device = speech.place
         batch_size = speech.shape[0]

         # Let's assume B = batch_size and N = beam_size
@@ -223,14 +249,14 @@ class U2Model(nn.Module):
             1, beam_size, 1, 1).view(running_size, 1, maxlen)  # (B*N, 1, max_len)

-        hyps = torch.ones(
-            [running_size, 1], dtype=torch.long,
-            device=device).fill_(self.sos)  # (B*N, 1)
-        scores = paddle.tensor(
-            [0.0] + [-float('inf')] * (beam_size - 1), dtype=torch.float)
+        hyps = paddle.ones(
+            [running_size, 1], dtype=paddle.int64).fill_(self.sos)  # (B*N, 1)
+        # log scale score
+        scores = paddle.to_tensor(
+            [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float32)
         scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(
             device)  # (B*N, 1)
-        end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
+        end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
         cache: Optional[List[paddle.Tensor]] = None
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
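With both branches guarded as above, U2-style models combine the two objectives as `ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att`, falling back to whichever branch ran when the other is disabled. A minimal sketch of that combination (plain Python with hypothetical scalar losses; the real code operates on paddle tensors):

    def combine_losses(loss_att, loss_ctc, ctc_weight):
        """Hybrid CTC/attention objective; either branch may be disabled (None)."""
        if loss_ctc is None:
            return loss_att
        if loss_att is None:
            return loss_ctc
        return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att

    print(combine_losses(2.0, 4.0, ctc_weight=0.3))  # 0.3*4.0 + 0.7*2.0 = 2.6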
diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py
index 4dd989d58..c0d54feb1 100644
--- a/deepspeech/modules/decoder.py
+++ b/deepspeech/modules/decoder.py
@@ -152,12 +152,12 @@ class TransformerDecoder(nn.Module):
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoded memory mask, (batch, 1, maxlen_in)
             tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out), dtype=paddle.bool
             cache: cached output list of (batch, max_time_out-1, size)
         Returns:
             y, cache: NN output value and cache per `self.decoders`.
-            y.shape` is (batch, maxlen_out, token)
+            `y.shape` is (batch, token)
         """
         x, _ = self.embed(tgt)
         new_cache = []
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index 9e1d34a89..456562df9 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -88,7 +88,10 @@ class LabelSmoothingLoss(nn.Layer):
             size (int): the number of class
             padding_idx (int): padding class id which will be ignored for loss
             smoothing (float): smoothing rate (0.0 means the conventional CE)
-            normalize_length (bool): True, normalize loss by sequence length; False, normalize loss by batch size. Defaults to False.
+            normalize_length (bool):
+                True, normalize loss by sequence length;
+                False, normalize loss by batch size.
+                Defaults to False.
         """
         super().__init__()
         self.size = size
@@ -103,6 +106,7 @@ class LabelSmoothingLoss(nn.Layer):
         The model outputs and data labels tensors are flatten to
         (batch*seqlen, class) shape and a mask is applied to the
         padding part which should not be calculated for loss.
+
         Args:
             x (paddle.Tensor): prediction (batch, seqlen, class)
             target (paddle.Tensor):
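The reflowed `normalize_length` docstring distinguishes two denominators: the summed per-token loss is divided by the number of non-padding tokens when True, and by the batch size when False. A small numpy sketch of that choice (hypothetical numbers; not the module's actual implementation):

    import numpy as np

    token_losses = np.array([0.5, 0.2, 0.3, 0.4])  # per-token losses, padding already masked out
    batch_size = 2                                 # sequences in the batch
    num_tokens = len(token_losses)                 # non-padding tokens

    loss_by_len = token_losses.sum() / num_tokens  # normalize_length=True  -> 0.35
    loss_by_bs = token_losses.sum() / batch_size   # normalize_length=False -> 0.70
    print(loss_by_len, loss_by_bs)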
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index 4351a7cb8..e38d75f8f 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -50,6 +50,52 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
     return mask


+def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of padded part.
+    See description of make_non_pad_mask.
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+    Returns:
+        paddle.Tensor: Mask tensor containing indices of padded part.
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = int(lengths.shape[0])
+    max_len = int(lengths.max())
+    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of non-padded part.
+    The sequences in a batch may have different lengths. To enable
+    batch computing, padding is needed to make all sequences the same
+    size. To keep the padding part from passing values to context-dependent
+    blocks such as attention or convolution, this padding part is
+    masked.
+    This pad_mask is used in both encoder and decoder.
+    1 for non-padded part and 0 for padded part.
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+    Returns:
+        paddle.Tensor: mask tensor containing indices of non-padded part.
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1, 1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+    """
+    return ~make_pad_mask(lengths)
+
+
 def subsequent_mask(size: int) -> paddle.Tensor:
     """Create mask for subsequent steps (size, size).
     This mask is used only in decoder which works in an auto-regressive mode.
@@ -170,52 +216,6 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
     return chunk_masks


-def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
-    """Make mask tensor containing indices of padded part.
-    See description of make_non_pad_mask.
-    Args:
-        lengths (paddle.Tensor): Batch of lengths (B,).
-    Returns:
-        paddle.Tensor: Mask tensor containing indices of padded part.
-    Examples:
-        >>> lengths = [5, 3, 2]
-        >>> make_pad_mask(lengths)
-        masks = [[0, 0, 0, 0 ,0],
-                 [0, 0, 0, 1, 1],
-                 [0, 0, 1, 1, 1]]
-    """
-    batch_size = int(lengths.shape[0])
-    max_len = int(lengths.max())
-    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
-    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
-    seq_length_expand = lengths.unsqueeze(-1)
-    mask = seq_range_expand >= seq_length_expand
-    return mask
-
-
-def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
-    """Make mask tensor containing indices of non-padded part.
-    The sequences in a batch may have different lengths. To enable
-    batch computing, padding is need to make all sequence in same
-    size. To avoid the padding part pass value to context dependent
-    block such as attention or convolution , this padding part is
-    masked.
-    This pad_mask is used in both encoder and decoder.
-    1 for non-padded part and 0 for padded part.
-    Args:
-        lengths (paddle.Tensor): Batch of lengths (B,).
-    Returns:
-        paddle.Tensor: mask tensor containing indices of padded part.
-    Examples:
-        >>> lengths = [5, 3, 2]
-        >>> make_non_pad_mask(lengths)
-        masks = [[1, 1, 1, 1 ,1],
-                 [1, 1, 1, 0, 0],
-                 [1, 1, 0, 0, 0]]
-    """
-    return ~make_pad_mask(lengths)
-
-
 def mask_finished_scores(score: paddle.Tensor,
                          flag: paddle.Tensor) -> paddle.Tensor:
     """
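The relocated helpers can be sanity-checked against their docstring examples. A short usage sketch (assuming `deepspeech.modules.mask` is importable as patched; output matches the docstrings above):

    import paddle
    from deepspeech.modules.mask import make_pad_mask, make_non_pad_mask

    lengths = paddle.to_tensor([5, 3, 2], dtype='int64')
    print(make_pad_mask(lengths).astype('int32'))
    # [[0, 0, 0, 0, 0],
    #  [0, 0, 0, 1, 1],
    #  [0, 0, 1, 1, 1]]
    print(make_non_pad_mask(lengths).astype('int32'))  # the elementwise negation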
diff --git a/tests/network_test.py b/tests/network_test.py
index 31c5efc79..ae86c9c43 100644
--- a/tests/network_test.py
+++ b/tests/network_test.py
@@ -46,7 +46,7 @@ if __name__ == '__main__':
         rnn_size=1024,
         use_gru=False,
         share_rnn_weights=False, )
-    logits, probs, logits_len = model(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")

@@ -58,7 +58,7 @@
         rnn_size=1024,
         use_gru=True,
         share_rnn_weights=False, )
-    logits, probs, logits_len = model2(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model2(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")

@@ -70,7 +70,7 @@
         rnn_size=1024,
         use_gru=False,
         share_rnn_weights=True, )
-    logits, probs, logits_len = model3(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model3(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")

@@ -82,7 +82,7 @@
         rnn_size=1024,
         use_gru=True,
         share_rnn_weights=True, )
-    logits, probs, logits_len = model4(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model4(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")

@@ -94,6 +94,6 @@
         rnn_size=1024,
         use_gru=False,
         share_rnn_weights=False, )
-    logits, probs, logits_len = model5(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model5(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")