From b23bde8ec5ff4ed3990f151246dfbb8c9dccf385 Mon Sep 17 00:00:00 2001
From: huangyuxin
Date: Wed, 25 May 2022 03:30:48 +0000
Subject: [PATCH] tensor.shape => paddle.shape(tensor)

---
 paddlespeech/s2t/__init__.py                         |  2 +-
 paddlespeech/s2t/decoders/beam_search/beam_search.py | 10 +++++-----
 paddlespeech/s2t/decoders/scorers/ctc.py             |  4 ++--
 .../s2t/decoders/scorers/ctc_prefix_score.py         | 12 ++++++------
 paddlespeech/s2t/models/lm/transformer.py            |  6 +++---
 paddlespeech/s2t/models/u2/u2.py                     |  2 +-
 paddlespeech/s2t/modules/decoder.py                  |  2 +-
 paddlespeech/s2t/modules/embedding.py                |  4 ++--
 paddlespeech/s2t/modules/encoder.py                  |  2 +-
 paddlespeech/s2t/utils/tensor_utils.py               |  4 ++--
 10 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py
index a2fce3057..2da68435c 100644
--- a/paddlespeech/s2t/__init__.py
+++ b/paddlespeech/s2t/__init__.py
@@ -200,7 +200,7 @@ if not hasattr(paddle.Tensor, 'view'):
 
 
 def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
-    return xs.reshape(ys.shape)
+    return xs.reshape(paddle.shape(ys))
 
 
 if not hasattr(paddle.Tensor, 'view_as'):
diff --git a/paddlespeech/s2t/decoders/beam_search/beam_search.py b/paddlespeech/s2t/decoders/beam_search/beam_search.py
index 5029e1577..f6a2b4b0a 100644
--- a/paddlespeech/s2t/decoders/beam_search/beam_search.py
+++ b/paddlespeech/s2t/decoders/beam_search/beam_search.py
@@ -231,7 +231,7 @@ class BeamSearch(paddle.nn.Layer):
 
         """
         # no pre beam performed, `ids` equal to `weighted_scores`
-        if weighted_scores.shape[0] == ids.shape[0]:
+        if paddle.shape(weighted_scores)[0] == paddle.shape(ids)[0]:
             top_ids = weighted_scores.topk(
                 self.beam_size)[1]  # index in n_vocab
             return top_ids, top_ids
@@ -370,13 +370,13 @@ class BeamSearch(paddle.nn.Layer):
         """
         # set length bounds
         if maxlenratio == 0:
-            maxlen = x.shape[0]
+            maxlen = paddle.shape(x)[0]
         elif maxlenratio < 0:
             maxlen = -1 * int(maxlenratio)
         else:
-            maxlen = max(1, int(maxlenratio * x.shape[0]))
-        minlen = int(minlenratio * x.shape[0])
-        logger.info("decoder input length: " + str(x.shape[0]))
+            maxlen = max(1, int(maxlenratio * paddle.shape(x)[0]))
+        minlen = int(minlenratio * paddle.shape(x)[0])
+        logger.info("decoder input length: " + str(paddle.shape(x)[0]))
         logger.info("max output length: " + str(maxlen))
         logger.info("min output length: " + str(minlen))
 
diff --git a/paddlespeech/s2t/decoders/scorers/ctc.py b/paddlespeech/s2t/decoders/scorers/ctc.py
index 6f1d8c007..3c1d4cf80 100644
--- a/paddlespeech/s2t/decoders/scorers/ctc.py
+++ b/paddlespeech/s2t/decoders/scorers/ctc.py
@@ -69,7 +69,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
             return sc[i], st[i]
         else:  # for CTCPrefixScorePD (need new_id > 0)
             r, log_psi, f_min, f_max, scoring_idmap = state
-            s = log_psi[i, new_id].expand(log_psi.shape[1])
+            s = log_psi[i, new_id].expand(paddle.shape(log_psi)[1])
             if scoring_idmap is not None:
                 return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
             else:
@@ -107,7 +107,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
 
         """
         logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
-        xlen = paddle.to_tensor([logp.shape[1]])
+        xlen = paddle.to_tensor([paddle.shape(logp)[1]])
         self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
         return None
 
diff --git a/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py b/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py
index 0e63a52a8..d8ca5ccde 100644
--- a/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py
+++ b/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py
@@ -33,9 +33,9 @@ class CTCPrefixScorePD():
         self.logzero = -10000000000.0
         self.blank = blank
         self.eos = eos
-        self.batch = x.shape[0]
-        self.input_length = x.shape[1]
-        self.odim = x.shape[2]
+        self.batch = paddle.shape(x)[0]
+        self.input_length = paddle.shape(x)[1]
+        self.odim = paddle.shape(x)[2]
         self.dtype = x.dtype
 
         # Pad the rest of posteriors in the batch
@@ -76,7 +76,7 @@ class CTCPrefixScorePD():
         last_ids = [yi[-1] for yi in y]  # last output label ids
         n_bh = len(last_ids)  # batch * hyps
         n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
-        self.scoring_num = scoring_ids.shape[-1] if scoring_ids is not None else 0
+        self.scoring_num = paddle.shape(scoring_ids)[-1] if scoring_ids is not None else 0
         # prepare state info
         if state is None:
             r_prev = paddle.full(
@@ -226,7 +226,7 @@ class CTCPrefixScorePD():
         if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
             # Pad the rest of posteriors in the batch
             # TODO(takaaki-hori): need a better way without for-loops
-            xlens = [x.shape[1]]
+            xlens = [paddle.shape(x)[1]]
             for i, l in enumerate(xlens):
                 if l < self.input_length:
                     x[i, l:, :] = self.logzero
@@ -236,7 +236,7 @@ class CTCPrefixScorePD():
             xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
             self.x = paddle.stack([xn, xb])  # (2, T, B, O)
             self.x[:, :tmp_x.shape[1], :, :] = tmp_x
-            self.input_length = x.shape[1]
+            self.input_length = paddle.shape(x)[1]
             self.end_frames = paddle.to_tensor(xlens) - 1
 
     def extend_state(self, state):
diff --git a/paddlespeech/s2t/models/lm/transformer.py b/paddlespeech/s2t/models/lm/transformer.py
index bb281168f..d14f99563 100644
--- a/paddlespeech/s2t/models/lm/transformer.py
+++ b/paddlespeech/s2t/models/lm/transformer.py
@@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
 
     def _target_mask(self, ys_in_pad):
         ys_mask = ys_in_pad != 0
-        m = subsequent_mask(ys_mask.shape[-1]).unsqueeze(0)
+        m = subsequent_mask(paddle.shape(ys_mask)[-1]).unsqueeze(0)
         return ys_mask.unsqueeze(-2) & m
 
     def forward(self, x: paddle.Tensor, t: paddle.Tensor
@@ -112,7 +112,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
             in perplexity: p(t)^{-n} = exp(-log p(t) / n)
 
         """
-        batch_size = x.shape[0]
+        batch_size = paddle.shape(x)[0]
         xm = x != 0
         xlen = xm.sum(axis=1)
         if self.embed_drop is not None:
@@ -122,7 +122,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
         h, _ = self.encoder(emb, xlen)
         y = self.decoder(h)
         loss = F.cross_entropy(
-            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+            y.view(-1, paddle.shape(y)[-1]), t.view(-1), reduction="none")
         mask = xm.to(loss.dtype)
         logp = loss * mask.view(-1)
         nll = logp.view(batch_size, -1).sum(-1)
diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index e3f46b15a..d5471369f 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -775,7 +775,7 @@ class U2DecodeModel(U2BaseModel):
         """
         self.eval()
         x = paddle.to_tensor(x).unsqueeze(0)
-        ilen = x.shape[1]
+        ilen = paddle.shape(x)[1]
         enc_output, _ = self._forward_encoder(x, ilen)
         return enc_output.squeeze(0)
 
diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py
index ce78059c0..ccc8482d5 100644
--- a/paddlespeech/s2t/modules/decoder.py
+++ b/paddlespeech/s2t/modules/decoder.py
@@ -242,7 +242,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         ]
 
         # batch decoding
-        ys_mask = subsequent_mask(ys.shape[-1]).unsqueeze(0)  # (B,L,L)
+        ys_mask = subsequent_mask(paddle.shape(ys)[-1]).unsqueeze(0)  # (B,L,L)
         xs_mask = make_xs_mask(xs).unsqueeze(1)  # (B,1,T)
         logp, states = self.forward_one_step(
             xs, xs_mask, ys, ys_mask, cache=batch_state)
diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py
index cc1fdffe2..51e558eb8 100644
--- a/paddlespeech/s2t/modules/embedding.py
+++ b/paddlespeech/s2t/modules/embedding.py
@@ -115,7 +115,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         assert offset + x.shape[
             1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
-        #TODO(Hui Zhang): using T = x.shape[1], __getitem__ not support Tensor
+        #TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + T]
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
@@ -165,6 +165,6 @@ class RelPositionalEncoding(PositionalEncoding):
             1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
         x = x * self.xscale
-        #TODO(Hui Zhang): using x.shape[1], __getitem__ not support Tensor
+        #TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index 7298c61f2..4d31acf1a 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -218,7 +218,7 @@ class BaseEncoder(nn.Layer):
         assert xs.shape[0] == 1  # batch size must be one
         # tmp_masks is just for interface compatibility
         # TODO(Hui Zhang): stride_slice not support bool tensor
-        # tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
+        # tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool)
         tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
         tmp_masks = tmp_masks.unsqueeze(1)  #[B=1, C=1, T]
 
diff --git a/paddlespeech/s2t/utils/tensor_utils.py b/paddlespeech/s2t/utils/tensor_utils.py
index ca8689569..bc557b130 100644
--- a/paddlespeech/s2t/utils/tensor_utils.py
+++ b/paddlespeech/s2t/utils/tensor_utils.py
@@ -59,7 +59,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
         >>> b = paddle.ones(22, 300)
         >>> c = paddle.ones(15, 300)
         >>> pad_sequence([a, b, c]).shape
-        [25, 3, 300]
+        paddle.Tensor([25, 3, 300])
 
     Note:
         This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
@@ -79,7 +79,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
 
     # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
-    max_size = sequences[0].shape
+    max_size = paddle.shape(sequences[0])
     # (TODO Hui Zhang): slice not supprot `end==start`
     # trailing_dims = max_size[1:]
     trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
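
Background note (illustration only, not part of the patch): the usual motivation for writing paddle.shape(tensor) instead of tensor.shape is dynamic-to-static export. In dynamic mode tensor.shape is a plain Python list, and when a function is traced with paddle.jit.to_static a dynamic dimension can be recorded as -1, whereas paddle.shape(tensor) is an operator whose result is a Tensor evaluated at run time. Below is a minimal standalone sketch of that pattern, assuming PaddlePaddle 2.x; the function and tensor names are invented for the example and do not appear in the patch.

    import paddle


    def frame_indices_static_friendly(x: paddle.Tensor) -> paddle.Tensor:
        # paddle.shape(x) yields a Tensor computed at run time, so the time
        # dimension stays symbolic after paddle.jit.to_static tracing.
        time_steps = paddle.shape(x)[1]
        return paddle.arange(time_steps)


    def frame_indices_python_int(x: paddle.Tensor) -> paddle.Tensor:
        # x.shape is a Python list; under to_static a dynamic dimension may be
        # baked in as -1, which breaks shape-dependent ops like this one.
        time_steps = x.shape[1]
        return paddle.arange(time_steps)


    if __name__ == "__main__":
        feats = paddle.randn([2, 7, 80])  # (batch, time, feat_dim)
        # Run eagerly, both variants return the indices 0..6.
        print(frame_indices_static_friendly(feats))
        print(frame_indices_python_int(feats))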