tensor.size to tensor.shape

pull/841/head
Hui Zhang 3 years ago
parent cd001d5daf
commit 000183ea49
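
Note: the change swaps PyTorch-style `tensor.size(dim)` calls for Paddle's `tensor.shape[dim]`. In PaddlePaddle 2.x, `Tensor.size` is the total element count, not a callable per-dimension accessor, so the old spelling breaks at runtime; indexing `Tensor.shape` works instead. A minimal sketch of the difference (standalone, not part of the diff):

    import paddle

    x = paddle.zeros([4, 16, 256])  # (B, T, D)
    print(x.shape[1])  # 16 -- per-dimension size, the form this commit uses
    print(x.size)      # 16384 -- total number of elements in Paddle
    # x.size(1) fails here: `size` is an int, not a method as in torch.Tensor.size(1)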

@@ -579,7 +579,7 @@ class U2Tester(U2Trainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
-maxlen = encoder_out.size(1)
+maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)

@@ -557,7 +557,7 @@ class U2Tester(U2Trainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
-maxlen = encoder_out.size(1)
+maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)

@@ -588,7 +588,7 @@ class U2STTester(U2STTrainer):
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
-maxlen = encoder_out.size(1)
+maxlen = encoder_out.shape[1]
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)

@@ -298,8 +298,8 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
-maxlen = encoder_out.size(1)
-encoder_dim = encoder_out.size(2)
+maxlen = encoder_out.shape[1]
+encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
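
Aside: the context lines above are the beam-search fan-out, where each utterance's encoder states are tiled once per beam. A hedged sketch of that reshape in stock Paddle (this repo appears to patch torch-style `repeat`/`view` onto `paddle.Tensor`; plain Paddle spells them `tile`/`reshape`, and the sizes below are illustrative):

    import paddle

    batch_size, beam_size = 2, 4
    encoder_out = paddle.randn([batch_size, 100, 256])  # (B, maxlen, encoder_dim)
    maxlen, encoder_dim = encoder_out.shape[1], encoder_out.shape[2]
    running_size = batch_size * beam_size
    # (B, maxlen, D) -> (B, 1, maxlen, D) -> (B, N, maxlen, D) -> (B*N, maxlen, D)
    tiled = encoder_out.unsqueeze(1).tile([1, beam_size, 1, 1]).reshape(
        [running_size, maxlen, encoder_dim])
    print(tiled.shape)  # [8, 100, 256]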
@@ -404,7 +404,7 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
-maxlen = encoder_out.size(1)
+maxlen = encoder_out.shape[1]
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
@@ -455,7 +455,7 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
-maxlen = encoder_out.size(1)
+maxlen = encoder_out.shape[1]
ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
@@ -583,7 +583,7 @@ class U2BaseModel(nn.Layer):
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
-(beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
+(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
@@ -690,13 +690,13 @@ class U2BaseModel(nn.Layer):
Returns:
paddle.Tensor: decoder output, (B, L)
"""
-assert encoder_out.size(0) == 1
-num_hyps = hyps.size(0)
-assert hyps_lens.size(0) == num_hyps
+assert encoder_out.shape[0] == 1
+num_hyps = hyps.shape[0]
+assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
-[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
+[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
@@ -751,7 +751,7 @@ class U2BaseModel(nn.Layer):
Returns:
List[List[int]]: transcripts.
"""
-batch_size = feats.size(0)
+batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1:
logger.fatal(
@@ -779,7 +779,7 @@ class U2BaseModel(nn.Layer):
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
-assert feats.size(0) == 1
+assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(
feats,
feats_lengths,
@@ -789,7 +789,7 @@ class U2BaseModel(nn.Layer):
simulate_streaming=simulate_streaming)
hyps = [hyp]
elif decoding_method == 'attention_rescoring':
-assert feats.size(0) == 1
+assert feats.shape[0] == 1
hyp = self.attention_rescoring(
feats,
feats_lengths,

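Aside: in `attention_rescoring` the batch-size-1 asserts above are structural: CTC prefix beam search produces an n-best list for a single utterance, and the decoder then scores every hypothesis in one batch against the repeated encoder output. A hedged sketch of how U2-style rescoring typically picks the winner (the names `eos`, `ctc_weight`, and the dummy data are assumptions, not copied from this file):

    import paddle
    import paddle.nn.functional as F

    eos, ctc_weight = 4, 0.5                    # assumed values
    hyps = [([1, 2], -1.2), ([1, 3], -0.9)]     # (token_ids, ctc_score) from prefix beam search
    decoder_out = F.log_softmax(
        paddle.randn([len(hyps), 3, 5]), axis=-1)  # (num_hyps, max_hyps_len+1, vocab)
    best_score, best_index = -float('inf'), 0
    for i, (token_ids, ctc_score) in enumerate(hyps):
        score = sum(float(decoder_out[i, j, w]) for j, w in enumerate(token_ids))
        score += float(decoder_out[i, len(token_ids), eos])  # end-of-sentence term
        score += ctc_weight * ctc_score                      # interpolate with the CTC score
        if score > best_score:
            best_score, best_index = score, i
    print(hyps[best_index][0])
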
@@ -340,8 +340,8 @@ class U2STBaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
-maxlen = encoder_out.size(1)
-encoder_dim = encoder_out.size(2)
+maxlen = encoder_out.shape[1]
+encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
@@ -496,13 +496,13 @@ class U2STBaseModel(nn.Layer):
Returns:
paddle.Tensor: decoder output, (B, L)
"""
-assert encoder_out.size(0) == 1
-num_hyps = hyps.size(0)
-assert hyps_lens.size(0) == num_hyps
+assert encoder_out.shape[0] == 1
+num_hyps = hyps.shape[0]
+assert hyps_lens.shape[0] == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
-[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
+[num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
@@ -557,7 +557,7 @@ class U2STBaseModel(nn.Layer):
Returns:
List[List[int]]: transcripts.
"""
-batch_size = feats.size(0)
+batch_size = feats.shape[0]
if decoding_method == 'fullsentence':
hyps = self.translate(

@@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
"""
-n_batch = query.size(0)
+n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
@@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer):
paddle.Tensor: Transformed value weighted
by the attention score, (#batch, time1, d_model).
"""
-n_batch = value.size(0)
+n_batch = value.shape[0]
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
@@ -172,15 +172,16 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
zero_pad = paddle.zeros(
-(x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype)
+(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)
-x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2))
+x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
+x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu:
-ones = paddle.ones((x.size(2), x.size(3)))
-x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+ones = paddle.ones((x.shape[2], x.shape[3]))
+x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
return x
@@ -205,7 +206,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
-n_batch_pos = pos_emb.size(0)
+n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)

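Aside: the `rel_shift` hunks above are the Transformer-XL pad-and-reshape trick: prepending a zero column and re-viewing the array with the last two axes' sizes swapped shifts each row so that entry (i, j) ends up holding the score for relative offset j - i; the leftover upper-triangle entries are artifacts, which is what the `zero_triu` masking removes. A NumPy sketch of the index gymnastics (NumPy used only to keep the illustration dependency-free):

    import numpy as np

    B, H, T = 1, 1, 3
    x = np.arange(T * T).reshape(B, H, T, T)           # x[..., i, k]: k indexes offsets -(T-1)..0
    zero_pad = np.zeros((B, H, T, 1), dtype=x.dtype)
    x_padded = np.concatenate([zero_pad, x], axis=-1)  # (B, H, T, T+1)
    shifted = x_padded.reshape(B, H, T + 1, T)[:, :, 1:].reshape(B, H, T, T)
    # shifted[..., i, j] == x[..., i, j - i + T - 1] for j <= i; j > i is garbage to mask
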
@@ -122,7 +122,7 @@ class TransformerDecoder(nn.Layer):
# tgt_mask: (B, 1, L)
tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1))
# m: (1, L, L)
-m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
+m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m

@@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
T = x.shape[1]
-assert offset + x.size(1) < self.max_len
+assert offset + x.shape[1] < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb
@@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
-assert offset + x.size(1) < self.max_len
+assert offset + x.shape[1] < self.max_len
x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]]

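Aside: both positional-encoding hunks above bound-check `offset + x.shape[1] < self.max_len` before slicing the precomputed table, and the offset is what lets streaming decoding continue the position count across chunks. A hedged sketch of that slicing (the `pe` table here is a random stand-in, not the module's real sin/cos buffer):

    import paddle

    max_len, d_model = 5000, 256
    pe = paddle.randn([1, max_len, d_model])  # placeholder for the precomputed table

    def position_encoding(offset: int, size: int) -> paddle.Tensor:
        assert offset + size < max_len        # same guard as the diff
        return pe[:, offset:offset + size]

    chunk1 = position_encoding(0, 16)    # first chunk covers positions 0..15
    chunk2 = position_encoding(16, 16)   # next chunk resumes at position 16
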
@@ -206,11 +206,11 @@ class BaseEncoder(nn.Layer):
chunk computation
List[paddle.Tensor]: conformer cnn cache
"""
-assert xs.size(0) == 1 # batch size must be one
+assert xs.shape[0] == 1 # batch size must be one
# tmp_masks is just for interface compatibility
# TODO(Hui Zhang): stride_slice not support bool tensor
# tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
-tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32)
+tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
if self.global_cmvn is not None:
@@ -220,25 +220,25 @@ class BaseEncoder(nn.Layer):
xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D)
if subsampling_cache is not None:
-cache_size = subsampling_cache.size(1) #T
+cache_size = subsampling_cache.shape[1] #T
xs = paddle.cat((subsampling_cache, xs), dim=1)
else:
cache_size = 0
# only used when using `RelPositionMultiHeadedAttention`
pos_emb = self.embed.position_encoding(
-offset=offset - cache_size, size=xs.size(1))
+offset=offset - cache_size, size=xs.shape[1])
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
-next_cache_start = xs.size(1)
+next_cache_start = xs.shape[1]
else:
-next_cache_start = xs.size(1) - required_cache_size
+next_cache_start = xs.shape[1] - required_cache_size
r_subsampling_cache = xs[:, next_cache_start:, :]
# Real mask for transformer/conformer layers
-masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
+masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1) #[B=1, L'=1, T]
r_elayers_output_cache = []
r_conformer_cnn_cache = []
@@ -302,7 +302,7 @@ class BaseEncoder(nn.Layer):
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
-num_frames = xs.size(1)
+num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
subsampling_cache: Optional[paddle.Tensor] = None
elayers_output_cache: Optional[List[paddle.Tensor]] = None
@@ -318,10 +318,10 @@ class BaseEncoder(nn.Layer):
chunk_xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
outputs.append(y)
-offset += y.size(1)
+offset += y.shape[1]
ys = paddle.cat(outputs, 1)
# fake mask, just for jit script and compatibility with `forward` api
-masks = paddle.ones([1, ys.size(1)], dtype=paddle.bool)
+masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks

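Aside: in `forward_chunk_by_chunk` the `offset += y.shape[1]` line above advances the absolute position by each chunk's subsampled output length, keeping positional encodings and caches aligned across chunks. A rough sketch of the window arithmetic from the surrounding context lines (the subsampling rate, context size, and loop bound are assumptions typical of a Conv2dSubsampling4 front end, not copied from this file):

    subsampling, context = 4, 7
    decoding_chunk_size, num_frames = 16, 203
    stride = subsampling * decoding_chunk_size                           # frames consumed per step
    decoding_window = (decoding_chunk_size - 1) * subsampling + context  # frames read per step
    for cur in range(0, num_frames - context + 1, stride):
        end = min(cur + decoding_window, num_frames)
        print(f"feed frames [{cur}, {end}), then offset += y.shape[1]")
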
@@ -84,11 +84,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros(
-(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
+(ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
state_path = (paddle.zeros(
-(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
+(ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
@@ -96,7 +96,7 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
-for t in range(1, ctc_probs.size(0)): # T
+for t in range(1, ctc_probs.shape[0]): # T
for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
@@ -116,7 +116,7 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
state_path[t, s] = prev_state[paddle.argmax(candidates)]
# TODO(Hui Zhang): zeros not support paddle.int16
-state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
+state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
@@ -124,11 +124,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[paddle.argmax(candidates)]
-for t in range(ctc_probs.size(0) - 2, -1, -1):
+for t in range(ctc_probs.shape[0] - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = []
-for t in range(0, ctc_probs.size(0)):
+for t in range(0, ctc_probs.shape[0]):
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment

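Aside: `forced_align` above is a Viterbi pass over the standard CTC topology: `insert_blank` interleaves the label sequence with blanks to form the 2L+1 states that the `(T, 2L+1)` tables index. A plausible sketch of that helper (the real `insert_blank` lives elsewhere in this repo; this is the conventional definition):

    def insert_blank(y, blank_id=0):
        """[a, b] -> [blank, a, blank, b, blank], i.e. length 2L+1."""
        out = [blank_id]
        for label in y:
            out.extend([int(label), blank_id])
        return out

    print(insert_blank([3, 5]))  # [0, 3, 0, 5, 0]
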
@@ -83,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
# (TODO Hui Zhang): slice not supprot `end==start`
# trailing_dims = max_size[1:]
trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
-max_len = max([s.size(0) for s in sequences])
+max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims
else:
@@ -91,7 +91,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
out_tensor = sequences[0].new_full(out_dims, padding_value)
for i, tensor in enumerate(sequences):
-length = tensor.size(0)
+length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
@@ -139,7 +139,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
-B = ys_pad.size(0)
+B = ys_pad.shape[0]
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
@@ -165,8 +165,8 @@ def th_accuracy(pad_outputs: paddle.Tensor,
Returns:
float: Accuracy value (0.0 - 1.0).
"""
-pad_pred = pad_outputs.view(
-pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
+pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
+pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))

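Aside: `th_accuracy` receives the decoder output flattened to (B*Lmax, vocab) and padded targets of shape (B, Lmax); the rewritten call above restores the batch structure, takes the argmax, and counts matches only on non-ignored positions. A hedged sketch in stock Paddle (`reshape` instead of the repo's patched `view`; numbers illustrative):

    import paddle

    ignore_label = -1
    pad_targets = paddle.to_tensor([[1, 2, -1], [3, -1, -1]])  # (B=2, Lmax=3)
    pad_outputs = paddle.randn([2 * 3, 5])                     # (B*Lmax, vocab=5)
    pad_pred = pad_outputs.reshape(
        [pad_targets.shape[0], pad_targets.shape[1], pad_outputs.shape[1]]).argmax(2)
    mask = pad_targets != ignore_label
    correct = paddle.sum(
        (pad_pred.masked_select(mask) == pad_targets.masked_select(mask)).astype('int64'))
    total = paddle.sum(mask.astype('int64'))
    print(float(correct) / float(total))  # fraction of non-padding tokens predicted correctly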