From 3ee6aed57d681040a98e4f8b78b20f8b7f4c0cdf Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Sat, 23 Oct 2021 10:40:39 +0000
Subject: [PATCH] using the subsampling cache

---
 deepspeech/models/lm/transformer.py | 67 +++++++++++++++--------------
 1 file changed, 35 insertions(+), 32 deletions(-)

diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py
index 8035ee631..c506f5772 100644
--- a/deepspeech/models/lm/transformer.py
+++ b/deepspeech/models/lm/transformer.py
@@ -56,7 +56,7 @@ class TransformerLM(nn.Layer):
             concat_after=False,
             static_chunk_size=1,
             use_dynamic_chunk=False,
-            use_dynamic_left_chunk=True, )
+            use_dynamic_left_chunk=False, )
 
         self.decoder = nn.Linear(att_unit, vocab_size)
 
@@ -66,13 +66,6 @@ class TransformerLM(nn.Layer):
         model_dict = paddle.load("transformerLM.pdparams")
         self.set_state_dict(model_dict)
 
-    def _target_len(self, ys_in_pad):
-        ys_len_tmp = paddle.where(
-            paddle.to_tensor(ys_in_pad != 0),
-            paddle.ones_like(ys_in_pad), paddle.zeros_like(ys_in_pad))
-        ys_len = paddle.sum(ys_len_tmp, axis=-1)
-        return ys_len
-
     def forward(self,
                 input: paddle.Tensor,
                 hidden: None) -> Tuple[paddle.Tensor, None]:
@@ -85,12 +78,12 @@ class TransformerLM(nn.Layer):
     def score(
             self,
             y: paddle.Tensor,
+            subsampling_cache,
             state: Any,
-            x: paddle.Tensor, ) -> Tuple[paddle.Tensor, Any]:
+            offset: int, ) -> Tuple[paddle.Tensor, Any]:
         # y, the chunk input
         y = y.unsqueeze(0)
-        offset = 0
-        subsampling_cache = None
+        subsampling_cache = subsampling_cache
         conformer_cnn_cache = None
         elayers_output_cache = state
         required_cache_size = -1
@@ -100,23 +93,25 @@ class TransformerLM(nn.Layer):
             elayers_output_cache, conformer_cnn_cache)
         h = self.decoder(h[:, -1])
         logp = F.log_softmax(h).squeeze(0)
-        return h, r_elayers_output_cache
+        return h, r_subsampling_cache, r_elayers_output_cache
 
-    def batch_score(self,
-                    ys: paddle.Tensor,
-                    states: List[Any],
-                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
+    def batch_score(
+            self,
+            ys: paddle.Tensor,
+            subsampling_caches: List[Any],
+            encoder_states: List[Any],
+            offset: int, ) -> Tuple[paddle.Tensor, List[Any]]:
+        #ys, the batch chunk input
         n_batch = ys.shape[0]
         n_layers = len(self.encoder.encoders)
         hs = []
-        new_states = []
+        new_subsampling_states = []
+        new_encoder_states = []
         for i in range(n_batch):
             y = ys[i:i + 1, :]
-            state = states[i]
-            offset = 0
-            subsampling_cache = None
+            subsampling_cache = subsampling_caches[i]
+            elayers_output_cache = encoder_states[i]
             conformer_cnn_cache = None
-            elayers_output_cache = state
             required_cache_size = -1
             y = self.embed(y)
             h, r_subsampling_cache, r_elayers_output_cache, r_conformer_cnn_cache = self.encoder.forward_chunk(
@@ -124,10 +119,11 @@ class TransformerLM(nn.Layer):
                 elayers_output_cache, conformer_cnn_cache)
             h = self.decoder(h[:, -1])
             hs.append(h)
-            new_states.append(r_elayers_output_cache)
+            new_subsampling_states.append(r_subsampling_cache)
+            new_encoder_states.append(r_elayers_output_cache)
         hs = paddle.concat(hs, axis=0)
         hs = F.log_softmax(hs)
-        return hs, new_states
+        return hs, new_subsampling_states, new_encoder_states
 
 
 if __name__ == "__main__":
@@ -144,26 +140,33 @@ if __name__ == "__main__":
     tlm.eval()
     """
+    #Test the score
     input2 = np.array([5])
     input2 = paddle.to_tensor(input2)
-    output, cache = tlm.score(input2, None, None)
+    output, sub_cache, cache = tlm.score(input2, None, None, 0)
 
-    input3 = np.array([5, 10])
+    input3 = np.array([10])
     input3 = paddle.to_tensor(input3)
-    output, cache = tlm.score(input3, cache, None)
+    output, sub_cache, cache = tlm.score(input3, sub_cache, cache, 1)
 
-    input4 = np.array([5, 10, 7])
+    input4 = np.array([7])
    input4 = paddle.to_tensor(input4)
-    output, cache = tlm.score(input4, cache, None)
+    output, sub_cache, cache = tlm.score(input4, sub_cache, cache, 2)
     print ("output", output)
     """
 
+    #Test the batch score
+    batch_size = 2
+    offset = 0
     inp2 = np.array([[5], [10]])
     inp2 = paddle.to_tensor(inp2)
-    output, cache = tlm.batch_score(inp2, [None] * 4, None)
+    output, subsampling_caches, encoder_caches = tlm.batch_score(
+        inp2, [None] * batch_size, [None] * batch_size, offset)
 
-    inp3 = np.array([[5, 100], [10, 30]])
+    offset += 1
+    inp3 = np.array([[100], [30]])
     inp3 = paddle.to_tensor(inp3)
-    output, cache = tlm.batch_score(inp3, cache, None)
+    output, subsampling_caches, encoder_caches = tlm.batch_score(
+        inp3, subsampling_caches, encoder_caches, offset)
     print("output", output)
-    print("cache", cache)
+    #print("cache", cache)
     #np.save("output_pd.npy", output)
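
Below is a minimal usage sketch (not part of the commit) of the API after this patch: score() and batch_score() now take the subsampling cache and the encoder-layer cache explicitly, plus a running offset, and return the updated caches to thread into the next chunk. The TransformerLM constructor arguments shown are illustrative assumptions, since the diff does not show how tlm is built in __main__; note that __init__ also loads "transformerLM.pdparams", so such a checkpoint must exist.

    # Usage sketch for the chunk-wise LM scoring API introduced by this patch.
    # Assumptions: the deepspeech package is importable and a checkpoint named
    # "transformerLM.pdparams" is present (TransformerLM loads it in __init__).
    # The constructor arguments are hypothetical; the diff does not show them.
    import numpy as np
    import paddle

    from deepspeech.models.lm.transformer import TransformerLM

    tlm = TransformerLM(vocab_size=5002)  # hypothetical args
    tlm.eval()

    # Single-stream scoring: feed one chunk (here one token) per call and
    # thread both returned caches plus the running offset into the next call.
    sub_cache, enc_cache = None, None
    with paddle.no_grad():
        for offset, token in enumerate([5, 10, 7]):
            y = paddle.to_tensor(np.array([token]))
            output, sub_cache, enc_cache = tlm.score(
                y, sub_cache, enc_cache, offset)

    # Batched scoring: one cache entry per hypothesis, threaded the same way.
    batch_size = 2
    sub_caches = [None] * batch_size
    enc_caches = [None] * batch_size
    with paddle.no_grad():
        for offset, chunk in enumerate([[[5], [10]], [[100], [30]]]):
            ys = paddle.to_tensor(np.array(chunk))
            output, sub_caches, enc_caches = tlm.batch_score(
                ys, sub_caches, enc_caches, offset)

The design point this exercises: the caller owns all recurrent state, so beam search can keep one (subsampling cache, encoder cache) pair per hypothesis and rescore incrementally instead of re-encoding the full prefix at every step.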