From 259781768e17b9147bd1af7129ba6f141fd89d4c Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 7 Jul 2021 09:16:05 +0000 Subject: [PATCH] comment u2 model for easy understand --- deepspeech/models/u2.py | 41 +++++++++++++++++++++++------------ deepspeech/modules/encoder.py | 5 ++++- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 6b266bdb..f1d466a2 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """U2 ASR Model -Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition +Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition (https://arxiv.org/pdf/2012.05481.pdf) """ import sys @@ -83,7 +83,7 @@ class U2BaseModel(nn.Module): # cnn_module_kernel=15, # activation_type='swish', # pos_enc_layer_type='rel_pos', - # selfattention_layer_type='rel_selfattn', + # selfattention_layer_type='rel_selfattn', )) # decoder related default.decoder = 'transformer' @@ -244,8 +244,8 @@ class U2BaseModel(nn.Module): simulate_streaming (bool, optional): streaming or not. Defaults to False. Returns: - Tuple[paddle.Tensor, paddle.Tensor]: - encoder hiddens (B, Tmax, D), + Tuple[paddle.Tensor, paddle.Tensor]: + encoder hiddens (B, Tmax, D), encoder hiddens mask (B, 1, Tmax). """ # Let's assume B = batch_size @@ -399,6 +399,7 @@ class U2BaseModel(nn.Module): assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 batch_size = speech.shape[0] + # Let's assume B = batch_size # encoder_out: (B, maxlen, encoder_dim) # encoder_mask: (B, 1, Tmax) @@ -410,10 +411,12 @@ class U2BaseModel(nn.Module): # encoder_out_lens = encoder_mask.squeeze(1).sum(1) encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen) topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] return hyps @@ -449,6 +452,7 @@ class U2BaseModel(nn.Module): batch_size = speech.shape[0] # For CTC prefix beam search, we only support batch_size=1 assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size # 1. Encoder forward and get CTC score encoder_out, encoder_mask = self._forward_encoder( @@ -458,7 +462,9 @@ class U2BaseModel(nn.Module): maxlen = encoder_out.size(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain cur_hyps = [(tuple(), (0.0, -float('inf')))] # 2. CTC beam search step by step for t in range(0, maxlen): @@ -498,6 +504,7 @@ class U2BaseModel(nn.Module): key=lambda x: log_add(list(x[1])), reverse=True) cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] return hyps, encoder_out @@ -561,12 +568,13 @@ class U2BaseModel(nn.Module): batch_size = speech.shape[0] # For attention rescoring we only support batch_size=1 assert batch_size == 1 - # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size + + # len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim) hyps, encoder_out = self._ctc_prefix_beam_search( speech, speech_lengths, beam_size, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) - assert len(hyps) == beam_size + hyps_pad = pad_sequence([ paddle.to_tensor(hyp[0], place=device, dtype=paddle.long) for hyp in hyps @@ -576,23 +584,28 @@ class U2BaseModel(nn.Module): dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining + encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool) decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) decoder_out = decoder_out.numpy() + # Only use decoder score for rescoring best_score = -float('inf') best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size for i, hyp in enumerate(hyps): score = 0.0 for j, w in enumerate(hyp[0]): score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.eos] - # add ctc score + # add ctc score (which in ln domain) score += hyp[1] * ctc_weight if score > best_score: best_score = score @@ -719,8 +732,8 @@ class U2BaseModel(nn.Module): feats (Tenosr): audio features, (B, T, D) feats_lengths (Tenosr): (B) text_feature (TextFeaturizer): text feature object. - decoding_method (str): decoding mode, e.g. - 'attention', 'ctc_greedy_search', + decoding_method (str): decoding mode, e.g. + 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' lang_model_path (str): lm path. beam_alpha (float): lm weight. @@ -728,19 +741,19 @@ class U2BaseModel(nn.Module): beam_size (int): beam size for search cutoff_prob (float): for prune. cutoff_top_n (int): for prune. - num_processes (int): + num_processes (int): ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. - 0: used for training, it's prohibited here. - num_decoding_left_chunks (int, optional): + 0: used for training, it's prohibited here. + num_decoding_left_chunks (int, optional): number of left chunks for decoding. Defaults to -1. simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. Raises: ValueError: when not support decoding_method. - + Returns: List[List[int]]: transcripts. """ @@ -821,7 +834,7 @@ class U2Model(U2BaseModel): ValueError: raise when using not support encoder type. Returns: - int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc + int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc """ if configs['cmvn_file'] is not None: mean, istd = load_cmvn(configs['cmvn_file'], diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index e326db8f..27e0f8d7 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -219,11 +219,14 @@ class BaseEncoder(nn.Layer): xs, pos_emb, _ = self.embed( xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D) + if subsampling_cache is not None: cache_size = subsampling_cache.size(1) #T xs = paddle.cat((subsampling_cache, xs), dim=1) else: cache_size = 0 + + # only used when using `RelPositionMultiHeadedAttention` pos_emb = self.embed.position_encoding( offset=offset - cache_size, size=xs.size(1)) @@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer): # Real mask for transformer/conformer layers masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) - masks = masks.unsqueeze(1) #[B=1, C=1, T] + masks = masks.unsqueeze(1) #[B=1, L'=1, T] r_elayers_output_cache = [] r_conformer_cnn_cache = [] for i, layer in enumerate(self.encoders):