From ef959bb49d6868ed0df504f8702df7bdf5120562 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Sun, 24 Oct 2021 11:21:24 +0000 Subject: [PATCH] Add the None for pos_enc and feed_one_step --- deepspeech/models/lm/transformer.py | 77 +++++++++++------------- deepspeech/modules/embedding.py | 16 +++++ deepspeech/modules/encoder.py | 90 ++++++++++++++++++++++++++++- 3 files changed, 139 insertions(+), 44 deletions(-) diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py index c506f5772..b5f51cbf5 100644 --- a/deepspeech/models/lm/transformer.py +++ b/deepspeech/models/lm/transformer.py @@ -22,6 +22,8 @@ import paddle.nn.functional as F from deepspeech.modules.encoder import TransformerEncoder +#LMInterface, BatchScorerInterface + class TransformerLM(nn.Layer): def __init__( @@ -39,7 +41,7 @@ class TransformerLM(nn.Layer): pos_enc_layer_type = "abs_pos" elif pos_enc is None: #TODO - raise ValueError(f"unknown pos-enc option: {pos_enc}") + pos_enc_layer_type = "None" else: raise ValueError(f"unknown pos-enc option: {pos_enc}") @@ -66,8 +68,15 @@ class TransformerLM(nn.Layer): model_dict = paddle.load("transformerLM.pdparams") self.set_state_dict(model_dict) + def _target_len(self, ys_in_pad): + ys_len_tmp = paddle.where( + paddle.to_tensor(ys_in_pad != 0), + paddle.ones_like(ys_in_pad), paddle.zeros_like(ys_in_pad)) + ys_len = paddle.sum(ys_len_tmp, axis=-1) + return ys_len + def forward(self, input: paddle.Tensor, - hidden: None) -> Tuple[paddle.Tensor, None]: + x_len: paddle.Tensor) -> Tuple[paddle.Tensor, None]: x = self.embed(input) x_len = self._target_len(input) @@ -75,61 +84,46 @@ class TransformerLM(nn.Layer): y = self.decoder(h) return y, None - def score( - self, - y: paddle.Tensor, - subsampling_cache, - state: Any, - offset: int, ) -> Tuple[paddle.Tensor, Any]: + def score(self, y: paddle.Tensor, state: Any, + x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]: # y, the chunk input y = y.unsqueeze(0) - subsampling_cache = subsampling_cache - conformer_cnn_cache = None - elayers_output_cache = state + #subsampling_cache, elayers_output_cache, conformer_cnn_cache, offset = state required_cache_size = -1 y = self.embed(y) - h, r_subsampling_cache, r_elayers_output_cache, r_conformer_cnn_cache = self.encoder.forward_chunk( - y, offset, required_cache_size, subsampling_cache, - elayers_output_cache, conformer_cnn_cache) + h, state = self.encoder.forward_one_step(y, required_cache_size, state) h = self.decoder(h[:, -1]) logp = F.log_softmax(h).squeeze(0) - return h, r_subsampling_cache, r_elayers_output_cache + return h, state def batch_score( self, ys: paddle.Tensor, - subsampling_caches: List[Any], - encoder_states: List[Any], - offset: int, ) -> Tuple[paddle.Tensor, List[Any]]: + states: List[Any], ) -> Tuple[paddle.Tensor, List[Any]]: #ys, the batch chunk input n_batch = ys.shape[0] n_layers = len(self.encoder.encoders) hs = [] - new_subsampling_states = [] - new_encoder_states = [] + new_states = [] for i in range(n_batch): y = ys[i:i + 1, :] - subsampling_cache = subsampling_caches[i] - elayers_output_cache = encoder_states[i] - conformer_cnn_cache = None + state = states[i] required_cache_size = -1 y = self.embed(y) - h, r_subsampling_cache, r_elayers_output_cache, r_conformer_cnn_cache = self.encoder.forward_chunk( - y, offset, required_cache_size, subsampling_cache, - elayers_output_cache, conformer_cnn_cache) + h, state = self.encoder.forward_one_step(y, required_cache_size, + state) h = self.decoder(h[:, -1]) hs.append(h) - 
new_subsampling_states.append(r_subsampling_cache) - new_encoder_states.append(r_elayers_output_cache) + new_states.append(state) hs = paddle.concat(hs, axis=0) hs = F.log_softmax(hs) - return hs, new_subsampling_states, new_encoder_states + return hs, new_states if __name__ == "__main__": tlm = TransformerLM( vocab_size=5002, - pos_enc='sinusoidal', + pos_enc=None, embed_unit=128, att_unit=512, head=8, @@ -139,34 +133,33 @@ if __name__ == "__main__": paddle.set_device("cpu") tlm.eval() - """ #Test the score input2 = np.array([5]) input2 = paddle.to_tensor(input2) - output, sub_cache, cache =tlm.score(input2, None, None, 0) + state = (None, None, 0) + output, state = tlm.score(input2, state, None) input3 = np.array([10]) input3 = paddle.to_tensor(input3) - output, sub_cache, cache = tlm.score(input3, sub_cache, cache, 1) + output, state = tlm.score(input3, state, None) - input4 = np.array([7]) + input4 = np.array([0]) input4 = paddle.to_tensor(input4) - output, sub_cache, cache = tlm.score(input4, sub_cache, cache, 2) - print ("output", output) + output, state = tlm.score(input4, state, None) + print("output", output) """ #Test the batch score batch_size = 2 - offset = 0 inp2 = np.array([[5], [10]]) inp2 = paddle.to_tensor(inp2) - output, subsampling_caches, encoder_caches = tlm.batch_score( - inp2, [None] * batch_size, [None] * batch_size, offset) + output, states = tlm.batch_score( + inp2, [(None,None,0)] * batch_size) - offset += 1 inp3 = np.array([[100], [30]]) inp3 = paddle.to_tensor(inp3) - output, subsampling_caches, encoder_caches = tlm.batch_score( - inp3, subsampling_caches, encoder_caches, offset) + output, states = tlm.batch_score( + inp3, states) print("output", output) #print("cache", cache) #np.save("output_pd.npy", output) + """ diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index fbbda023c..43c40b6d4 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -25,6 +25,22 @@ logger = Log(__name__).getlog() __all__ = ["PositionalEncoding", "RelPositionalEncoding"] +class NoPositionalEncoding(nn.Layer): + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int=5000, + reverse: bool=False): + super().__init__() + + def forward(self, x: paddle.Tensor, + offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: + return x, None + + def position_encoding(self, offset: int, size: int) -> paddle.Tensor: + return None + + class PositionalEncoding(nn.Layer): def __init__(self, d_model: int, diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index 6ffb6465c..02d58dd4b 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -24,6 +24,7 @@ from deepspeech.modules.activation import get_activation from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import RelPositionMultiHeadedAttention from deepspeech.modules.conformer_convolution import ConvolutionModule +from deepspeech.modules.embedding import NoPositionalEncoding from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import RelPositionalEncoding from deepspeech.modules.encoder_layer import ConformerEncoderLayer @@ -101,6 +102,8 @@ class BaseEncoder(nn.Layer): pos_enc_class = PositionalEncoding elif pos_enc_layer_type == "rel_pos": pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "None": + pos_enc_class = NoPositionalEncoding else: raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) @@ -155,11 +158,11 
@@ class BaseEncoder(nn.Layer):
             encoder output tensor, lens and mask
         """
         masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)
-
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
         #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
         xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
+        #print("xs", xs)
         #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
         masks = masks.astype(paddle.bool)
         #TODO(Hui Zhang): mask_pad = ~masks
@@ -370,6 +373,80 @@ class TransformerEncoder(BaseEncoder):
                 concat_after=concat_after) for _ in range(num_blocks)
         ])
 
+    def forward_one_step(
+            self,
+            xs: paddle.Tensor,
+            required_cache_size: int,
+            state=(None, None, 0),
+    ) -> Tuple[paddle.Tensor, Tuple[paddle.Tensor, List[paddle.Tensor],
+                                    int]]:
+        """ Forward just one step of the input (one chunk).
+        Args:
+            xs (paddle.Tensor): current step input, [B=1, T, D]
+            required_cache_size (int): cache size required for next step
+                computation
+                >=0: actual cache size
+                <0: means all history cache is required
+            state (Tuple): (subsampling_cache, elayers_output_cache, offset)
+                returned by the previous call, where
+                subsampling_cache (Optional[paddle.Tensor]) is the cached
+                    output of the embedding/subsampling layer,
+                elayers_output_cache (Optional[List[paddle.Tensor]]) is the
+                    cached output of each encoder layer, and
+                offset (int) is the current offset in the encoder output
+                    time stamp.
+        Returns:
+            paddle.Tensor: output of the current input xs
+            Tuple: new (subsampling_cache, elayers_output_cache, offset)
+                state required for the next step computation
+        """
+        assert xs.shape[0] == 1  # batch size must be one
+        # tmp_masks is just for interface compatibility
+        # TODO(Hui Zhang): stride_slice not support bool tensor
+        # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
+        subsampling_cache, elayers_output_cache, offset = state
+        tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
+        tmp_masks = tmp_masks.unsqueeze(1)  #[B=1, C=1, T]
+
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+
+        xs, pos_emb, _ = self.embed(
+            xs, tmp_masks, offset=offset)  #xs=(B, T, D), pos_emb=(B=1, T, D)
+
+        if subsampling_cache is not None:
+            cache_size = subsampling_cache.shape[1]  #T
+            xs = paddle.concat((subsampling_cache, xs), axis=1)
+        else:
+            cache_size = 0
+
+        # only used when using `RelPositionMultiHeadedAttention`
+        pos_emb = self.embed.position_encoding(
+            offset=offset - cache_size, size=xs.shape[1])
+
+        if required_cache_size < 0:
+            next_cache_start = 0
+        elif required_cache_size == 0:
+            next_cache_start = xs.shape[1]
+        else:
+            next_cache_start = xs.shape[1] - required_cache_size
+        r_subsampling_cache = xs[:, next_cache_start:, :]
+
+        # Real mask for transformer/conformer layers
+        masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
+        masks = masks.unsqueeze(1)  #[B=1, 1, T]
+        r_elayers_output_cache = []
+        for i, layer in enumerate(self.encoders):
+            attn_cache = None if elayers_output_cache is None else elayers_output_cache[
+                i]
+            xs, _, _ = layer(
+                xs, masks, pos_emb, output_cache=attn_cache, cnn_cache=None)
+            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+        new_state = (r_subsampling_cache, r_elayers_output_cache, offset + 1)
+        return (xs[:, cache_size:, :], new_state)
+
 
 class ConformerEncoder(BaseEncoder):
     """Conformer encoder module."""
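
Note (illustrative, not part of the patch): score() and batch_score() now thread a
single state tuple (subsampling_cache, elayers_output_cache, offset) through
TransformerEncoder.forward_one_step(), so each call encodes only the newly fed
token and reuses the cached layer outputs for the prefix. The sketch below shows
how a caller is expected to drive this incremental API, mirroring the test code in
the __main__ block of deepspeech/models/lm/transformer.py; the helper name
score_prefix is hypothetical, and tlm is assumed to be a TransformerLM built and
switched to eval mode exactly as in that block (its __init__ loads the
"transformerLM.pdparams" checkpoint, which must be present).

    import numpy as np
    import paddle

    def score_prefix(tlm, token_ids):
        """Feed a token prefix one step at a time through tlm.score()."""
        # Initial state: no subsampling cache, no per-layer cache, offset 0.
        state = (None, None, 0)
        output = None
        for token_id in token_ids:
            y = paddle.to_tensor(np.array([token_id]))
            # The third argument (x) is unused by TransformerLM.score(); pass None.
            output, state = tlm.score(y, state, None)
        return output, state

    # output holds the decoder scores for the next token, shape [1, vocab_size];
    # state carries the caches and offset needed to extend this prefix further.
    output, state = score_prefix(tlm, [5, 10, 0])

For batch_score(), the same tuple is kept per hypothesis, e.g.
states = [(None, None, 0)] * batch_size, as in the commented-out test above.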