tensor.shape => paddle.shape(tensor)

pull/1950/head
huangyuxin 3 years ago
parent 4c09927f61
commit b23bde8ec5
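
Annotation: the single pattern applied throughout this commit is to replace the Python attribute `x.shape` with the `paddle.shape(x)` op, so sizes are computed as tensors at run time. This matters for static-graph export (`paddle.jit.to_static`), where `x.shape` can contain -1 for dynamic axes while `paddle.shape(x)` always yields the concrete runtime values. A minimal sketch of the difference (not part of the diff):

    import paddle

    x = paddle.ones([4, 8])
    print(x.shape)          # Python list: [4, 8]; may hold -1 under to_static
    print(paddle.shape(x))  # int32 Tensor holding the runtime sizes: [4, 8]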

@@ -200,7 +200,7 @@ if not hasattr(paddle.Tensor, 'view'):
 def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
-    return xs.reshape(ys.shape)
+    return xs.reshape(paddle.shape(ys))

 if not hasattr(paddle.Tensor, 'view_as'):
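
Annotation: the rewritten `view_as` relies on `paddle.reshape` accepting the target shape as a Tensor as well as a list of ints. A sketch:

    import paddle

    xs = paddle.arange(12, dtype='float32')
    ys = paddle.zeros([3, 4])
    out = xs.reshape(paddle.shape(ys))  # target shape supplied as a runtime tensor
    assert out.shape == [3, 4]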

@@ -231,7 +231,7 @@ class BeamSearch(paddle.nn.Layer):
         """
         # no pre beam performed, `ids` equal to `weighted_scores`
-        if weighted_scores.shape[0] == ids.shape[0]:
+        if paddle.shape(weighted_scores)[0] == paddle.shape(ids)[0]:
             top_ids = weighted_scores.topk(
                 self.beam_size)[1]  # index in n_vocab
             return top_ids, top_ids
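
Annotation: comparing two elements of `paddle.shape(...)` yields a single-element boolean tensor, which Python's `if` can evaluate directly in dygraph mode; under `to_static` the conditional is converted for the graph. A sketch, assuming dygraph and hypothetical sizes:

    import paddle

    weighted_scores = paddle.randn([10, 5000])
    ids = paddle.arange(10)
    if paddle.shape(weighted_scores)[0] == paddle.shape(ids)[0]:  # 1-element bool tensor
        print("no pre-beam: score rows align with ids")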
@@ -370,13 +370,13 @@ class BeamSearch(paddle.nn.Layer):
         """
         # set length bounds
         if maxlenratio == 0:
-            maxlen = x.shape[0]
+            maxlen = paddle.shape(x)[0]
         elif maxlenratio < 0:
             maxlen = -1 * int(maxlenratio)
         else:
-            maxlen = max(1, int(maxlenratio * x.shape[0]))
-        minlen = int(minlenratio * x.shape[0])
-        logger.info("decoder input length: " + str(x.shape[0]))
+            maxlen = max(1, int(maxlenratio * paddle.shape(x)[0]))
+        minlen = int(minlenratio * paddle.shape(x)[0])
+        logger.info("decoder input length: " + str(paddle.shape(x)[0]))
         logger.info("max output length: " + str(maxlen))
         logger.info("min output length: " + str(minlen))

@@ -69,7 +69,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
             return sc[i], st[i]
         else:  # for CTCPrefixScorePD (need new_id > 0)
             r, log_psi, f_min, f_max, scoring_idmap = state
-            s = log_psi[i, new_id].expand(log_psi.shape[1])
+            s = log_psi[i, new_id].expand(paddle.shape(log_psi)[1])
             if scoring_idmap is not None:
                 return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
             else:
@@ -107,7 +107,7 @@ class CTCPrefixScorer(BatchPartialScorerInterface):
         """
         logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
-        xlen = paddle.to_tensor([logp.shape[1]])
+        xlen = paddle.to_tensor([paddle.shape(logp)[1]])
         self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
         return None

@@ -33,9 +33,9 @@ class CTCPrefixScorePD():
         self.logzero = -10000000000.0
         self.blank = blank
         self.eos = eos
-        self.batch = x.shape[0]
-        self.input_length = x.shape[1]
-        self.odim = x.shape[2]
+        self.batch = paddle.shape(x)[0]
+        self.input_length = paddle.shape(x)[1]
+        self.odim = paddle.shape(x)[2]
         self.dtype = x.dtype

         # Pad the rest of posteriors in the batch
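
Annotation: after this change `self.batch`, `self.input_length`, and `self.odim` are 1-element int32 tensors rather than Python ints. A sketch of reading the three dims:

    import paddle

    x = paddle.randn([2, 50, 31])  # (B, T, O) CTC log-posteriors, hypothetical sizes
    s = paddle.shape(x)            # int32 tensor [2, 50, 31]
    batch, input_length, odim = s[0], s[1], s[2]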
@@ -76,7 +76,7 @@ class CTCPrefixScorePD():
         last_ids = [yi[-1] for yi in y]  # last output label ids
         n_bh = len(last_ids)  # batch * hyps
         n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
-        self.scoring_num = scoring_ids.shape[-1] if scoring_ids is not None else 0
+        self.scoring_num = paddle.shape(scoring_ids)[-1] if scoring_ids is not None else 0
         # prepare state info
         if state is None:
             r_prev = paddle.full(
@@ -226,7 +226,7 @@ class CTCPrefixScorePD():
         if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
             # Pad the rest of posteriors in the batch
             # TODO(takaaki-hori): need a better way without for-loops
-            xlens = [x.shape[1]]
+            xlens = [paddle.shape(x)[1]]
             for i, l in enumerate(xlens):
                 if l < self.input_length:
                     x[i, l:, :] = self.logzero
@@ -236,7 +236,7 @@ class CTCPrefixScorePD():
             xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
             self.x = paddle.stack([xn, xb])  # (2, T, B, O)
             self.x[:, :tmp_x.shape[1], :, :] = tmp_x
-            self.input_length = x.shape[1]
+            self.input_length = paddle.shape(x)[1]
         self.end_frames = paddle.to_tensor(xlens) - 1

     def extend_state(self, state):

@@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
     def _target_mask(self, ys_in_pad):
         ys_mask = ys_in_pad != 0
-        m = subsequent_mask(ys_mask.shape[-1]).unsqueeze(0)
+        m = subsequent_mask(paddle.shape(ys_mask)[-1]).unsqueeze(0)
         return ys_mask.unsqueeze(-2) & m

     def forward(self, x: paddle.Tensor, t: paddle.Tensor
@@ -112,7 +112,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
             in perplexity: p(t)^{-n} = exp(-log p(t) / n)
         """
-        batch_size = x.shape[0]
+        batch_size = paddle.shape(x)[0]
         xm = x != 0
         xlen = xm.sum(axis=1)
         if self.embed_drop is not None:
@@ -122,7 +122,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
         h, _ = self.encoder(emb, xlen)
         y = self.decoder(h)
         loss = F.cross_entropy(
-            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+            y.view(-1, paddle.shape(y)[-1]), t.view(-1), reduction="none")
         mask = xm.to(loss.dtype)
         logp = loss * mask.view(-1)
         nll = logp.view(batch_size, -1).sum(-1)
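
Annotation: the flattened cross entropy reshapes logits to (B*L, V), with the vocab dim now read at run time, and targets to (B*L,); `reduction="none"` keeps a per-token loss so padding can be masked afterwards. A self-contained sketch with hypothetical sizes:

    import paddle
    import paddle.nn.functional as F

    B, L, V = 2, 5, 10
    y = paddle.randn([B, L, V])       # decoder logits
    t = paddle.randint(0, V, [B, L])  # target token ids
    loss = F.cross_entropy(
        y.reshape([-1, paddle.shape(y)[-1]]), t.reshape([-1]), reduction="none")
    assert loss.shape == [B * L]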

@@ -775,7 +775,7 @@ class U2DecodeModel(U2BaseModel):
         """
         self.eval()
         x = paddle.to_tensor(x).unsqueeze(0)
-        ilen = x.shape[1]
+        ilen = paddle.shape(x)[1]
         enc_output, _ = self._forward_encoder(x, ilen)
         return enc_output.squeeze(0)
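
Annotation: `encode` wraps a single utterance: `unsqueeze(0)` adds the batch axis, and the input length is now read from the tensor itself. A sketch:

    import paddle

    feats = paddle.randn([123, 80])  # (T, D) features for one utterance, hypothetical
    x = feats.unsqueeze(0)           # (1, T, D)
    ilen = paddle.shape(x)[1]        # runtime length, kept as a tensor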

@@ -242,7 +242,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
         ]

         # batch decoding
-        ys_mask = subsequent_mask(ys.shape[-1]).unsqueeze(0)  # (B,L,L)
+        ys_mask = subsequent_mask(paddle.shape(ys)[-1]).unsqueeze(0)  # (B,L,L)
         xs_mask = make_xs_mask(xs).unsqueeze(1)  # (B,1,T)
         logp, states = self.forward_one_step(
             xs, xs_mask, ys, ys_mask, cache=batch_state)

@@ -115,7 +115,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         assert offset + x.shape[
             1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
-        #TODO(Hui Zhang): using T = x.shape[1], __getitem__ not support Tensor
+        #TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + T]
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
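
Annotation: only the TODO comment changes here: the slice `self.pe[:, offset:offset + T]` still needs `T` as a Python int, because Tensor `__getitem__` did not accept Tensor indices at the time. A hypothetical workaround is to materialize the dim on the host first (at the cost of baking it in as a constant under tracing):

    import paddle

    pe = paddle.randn([1, 5000, 256])  # hypothetical positional-encoding table
    x = paddle.randn([1, 37, 256])
    offset = 0
    T = int(paddle.shape(x)[1])        # host-side int; not trace-friendly
    pos_emb = pe[:, offset:offset + T]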
@@ -165,6 +165,6 @@ class RelPositionalEncoding(PositionalEncoding):
             1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
         x = x * self.xscale
-        #TODO(Hui Zhang): using x.shape[1], __getitem__ not support Tensor
+        #TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)

@@ -218,7 +218,7 @@ class BaseEncoder(nn.Layer):
         assert xs.shape[0] == 1  # batch size must be one
         # tmp_masks is just for interface compatibility
         # TODO(Hui Zhang): stride_slice not support bool tensor
-        # tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
+        # tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool)
         tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
         tmp_masks = tmp_masks.unsqueeze(1)  # [B=1, C=1, T]
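
Annotation: the bool-typed mask stays commented out because `stride_slice` lacked bool support at the time; the live code builds the same mask as int32. A sketch of the workaround with hypothetical sizes:

    import paddle

    xs = paddle.randn([1, 97, 80])  # (B=1, T, D)
    tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
    tmp_masks = tmp_masks.unsqueeze(1)  # [B=1, C=1, T]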

@@ -59,7 +59,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
         >>> b = paddle.ones(22, 300)
         >>> c = paddle.ones(15, 300)
         >>> pad_sequence([a, b, c]).shape
-        [25, 3, 300]
+        paddle.Tensor([25, 3, 300])

     Note:
         This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
@@ -79,7 +79,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
     # assuming trailing dimensions and type of all the Tensors
     # in sequences are same and fetching those from sequences[0]
-    max_size = sequences[0].shape
+    max_size = paddle.shape(sequences[0])
     # (TODO Hui Zhang): slice not supprot `end==start`
     # trailing_dims = max_size[1:]
     trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
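
Annotation: `paddle.shape(sequences[0])` returns an int32 tensor, so the trailing dims are pulled back into Python via `.numpy().tolist()` before being used to build the output size tuple. A sketch:

    import paddle

    seqs = [paddle.ones([25, 300]), paddle.ones([22, 300]), paddle.ones([15, 300])]
    max_size = paddle.shape(seqs[0])                      # int32 tensor [25, 300]
    trailing_dims = tuple(max_size[1:].numpy().tolist())  # Python ints: (300,)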
