bool logical, sum and multiply ops; CTC grad norm; support old and new Paddle API

pull/879/head
Hui Zhang 3 years ago
parent 81f89c53e6
commit c8e96d732b

@@ -41,6 +41,13 @@ def conv_output_size(I, F, P, S):
return (I - F + 2 * P - S) // S
# receptive field calculator
# https://fomoro.com/research/article/receptive-field-calculator
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
# https://distill.pub/2019/computing-receptive-fields/
# Rl-1 = Sl * Rl + (Kl - Sl)
class ConvBn(nn.Layer):
"""Convolution layer with batch normalization.
@@ -106,9 +113,10 @@ class ConvBn(nn.Layer):
# reset padding part to 0
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
# TODO(Hui Zhang): not support bool multiply
# masks = masks.type_as(x)
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
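
As a quick aside on the receptive-field references added at the top of this hunk, the recursion in the comment, R_{l-1} = S_l * R_l + (K_l - S_l), can be evaluated in a few lines. The helper below is a minimal illustrative sketch and is not part of the commit:

def receptive_field(kernels, strides):
    # Walk the recursion backwards from the output, where the field is 1.
    r = 1
    for k, s in zip(reversed(kernels), reversed(strides)):
        r = s * r + (k - s)
    return r

# Two 3x3 convolutions with stride 2 cover 7 input steps.
assert receptive_field([3, 3], [2, 2]) == 7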

@@ -308,8 +308,8 @@ class RNNStack(nn.Layer):
x, x_len = rnn(x, x_len)
masks = make_non_pad_mask(x_len) #[B, T]
masks = masks.unsqueeze(-1) # [B, T, 1]
# https://github.com/PaddlePaddle/Paddle/pull/29265
# rhs will type promote to lhs
x = x * masks
# TODO(Hui Zhang): not support bool multiply
masks = masks.astype(x.dtype)
x = x.multiply(masks)
return x, x_len
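
Both the ConvBn and RNNStack hunks above apply the same workaround: the boolean non-pad mask is cast to the feature dtype before the elementwise multiply, because the targeted Paddle version cannot multiply a float tensor by a bool tensor (and type_as is unavailable). A standalone sketch with made-up shapes:

import paddle

x = paddle.rand([2, 5, 8])                     # [B, T, D] features
lengths = paddle.to_tensor([5, 3])
masks = paddle.arange(x.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)  # [B, T], bool
masks = masks.unsqueeze(-1)                    # [B, T, 1]
# Cast the bool mask to the feature dtype, then multiply to zero the padding.
x = x.multiply(masks.astype(x.dtype))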

@@ -164,7 +164,10 @@ class U2BaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch
loss_att = None
@@ -319,7 +322,8 @@ class U2BaseModel(nn.Layer):
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
if end_flag.sum() == running_size:
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
@@ -405,7 +409,9 @@ class U2BaseModel(nn.Layer):
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
maxlen = encoder_out.shape[1]
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
# (TODO Hui Zhang): bool no support reduce_sum
# encoder_out_lens = encoder_mask.squeeze(1).sum(1)
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
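
These U2BaseModel hunks, and the matching U2STBaseModel hunks below, all work around the missing bool reduce_sum in the same way: cast the boolean mask or flag to an integer dtype, then sum. A minimal sketch with toy values:

import paddle

encoder_mask = paddle.to_tensor([[[True, True, True, False]],
                                 [[True, True, False, False]]])   # [B, 1, T]
# Cast to int64 before the reduction; summing bool tensors is unsupported.
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(1)   # -> [3, 2]

end_flag = paddle.to_tensor([True, True, True])
running_size = 3
# The beam-search early-stop check uses the same trick.
all_finished = end_flag.cast(paddle.int64).sum() == running_size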

@@ -165,7 +165,10 @@ class U2STBaseModel(nn.Layer):
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch
start = time.time()
@@ -362,7 +365,8 @@ class U2STBaseModel(nn.Layer):
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
if end_flag.sum() == running_size:
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step

@@ -124,7 +124,9 @@ class TransformerDecoder(nn.Layer):
# m: (1, L, L)
m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
# TODO(Hui Zhang): not support & for tensor
# tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt)
for layer in self.decoders:
@@ -135,7 +137,9 @@ class TransformerDecoder(nn.Layer):
if self.use_output_layer:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
# TODO(Hui Zhang): reduce_sum not support bool type
# olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens
def forward_one_step(
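
In the decoder, the & operator and the bool sum are likewise replaced by logical_and and an integer cast before the reduction. A self-contained sketch of the same masking step, with assumed toy shapes:

import paddle

L = 4
tgt_mask = paddle.ones([2, 1, L], dtype=paddle.bool)                    # (B, 1, L)
m = paddle.tril(paddle.ones([L, L])).astype(paddle.bool).unsqueeze(0)   # (1, L, L)
tgt_mask = tgt_mask.logical_and(m)                                      # (B, L, L)
olens = tgt_mask.astype(paddle.int64).sum(1)                            # cast, then sum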

@@ -162,7 +162,8 @@ class BaseEncoder(nn.Layer):
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
mask_pad = ~masks
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,
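
The encoder applies the same idea to bitwise not: ~masks becomes masks.logical_not(), which is defined for bool tensors on both the old and new Paddle APIs. In isolation:

import paddle

masks = paddle.to_tensor([[True, True, True, False, False]])   # [B, T] non-pad mask
mask_pad = masks.logical_not()                                 # True on padding frames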

@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
import paddle
from paddle import nn
from paddle.nn import functional as F
@@ -32,18 +35,19 @@ class CTCLoss(nn.Layer):
# last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average
logger.info(
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
# instance for norm_by_times
# batch for norm_by_batchsize
# frame for norm_by_total_logits_len
assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False
self.norm_by_batchsize = False
self.norm_by_total_logits_len = False
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
if grad_norm_type == 'instance':
if grad_norm_type is None:
# no grad norm
pass
elif grad_norm_type == 'instance':
self.norm_by_times = True
elif grad_norm_type == 'batch':
self.norm_by_batchsize = True
@@ -51,6 +55,22 @@ class CTCLoss(nn.Layer):
self.norm_by_total_logits_len = True
else:
raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
self.kwargs = {
"norm_by_times": self.norm_by_times,
"norm_by_batchsize": self.norm_by_batchsize,
"norm_by_total_logits_len": self.norm_by_total_logits_len,
}
# Derive only the args which the func has
try:
param = inspect.signature(self.loss.forward).parameters
except ValueError:
# Some function, e.g. built-in function, are failed
param = {}
_kwargs = {k: v for k, v in self.kwargs.items() if k in param}
_notin = {k: v for k, v in self.kwargs.items() if k not in param}
logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}")
self.loss_fn = partial(self.loss.forward, **_kwargs)
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.
@@ -70,14 +90,7 @@ class CTCLoss(nn.Layer):
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2])
ys_pad = ys_pad.astype(paddle.int32)
loss = self.loss(
logits,
ys_pad,
hlens,
ys_lens,
norm_by_times=self.norm_by_times,
norm_by_batchsize=self.norm_by_batchsize,
norm_by_total_logits_len=self.norm_by_total_logits_len)
loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
if self.batch_average:
# Batch-size average
loss = loss / B
@@ -118,8 +131,8 @@ class LabelSmoothingLoss(nn.Layer):
size (int): the number of class
padding_idx (int): padding class id which will be ignored for loss
smoothing (float): smoothing rate (0.0 means the conventional CE)
normalize_length (bool):
True, normalize loss by sequence length;
normalize_length (bool):
True, normalize loss by sequence length;
False, normalize loss by batch size.
Defaults to False.
"""
@@ -136,7 +149,7 @@ class LabelSmoothingLoss(nn.Layer):
The model outputs and data labels tensors are flatten to
(batch*seqlen, class) shape and a mask is applied to the
padding part which should not be calculated for loss.
Args:
x (paddle.Tensor): prediction (batch, seqlen, class)
target (paddle.Tensor):
@@ -152,7 +165,7 @@ class LabelSmoothingLoss(nn.Layer):
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
ignore = (target == self.padding_idx) # (B,)
ignore = target == self.padding_idx # (B,)
#TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index
target = target.masked_fill(ignore, 0) # avoid -1 index
@@ -163,8 +176,10 @@ class LabelSmoothingLoss(nn.Layer):
kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
total = len(target) - int(ignore.sum())
#TODO(Hui Zhang): sum not support bool type
#total = len(target) - int(ignore.sum())
total = len(target) - int(ignore.type_as(target).sum())
denom = total if self.normalize_length else B
#TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum()
#numer = (kl * (1 - ignore)).sum()
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
return numer / denom
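
The inspect/partial logic added to CTCLoss above is what the commit title means by supporting both the old and new Paddle APIs: the grad-norm keyword arguments are forwarded only if the installed paddle.nn.CTCLoss.forward actually declares them, so the same code runs against either version. The pattern in isolation, with a made-up stand-in loss class:

import inspect
from functools import partial

class OldStyleCTCLoss:
    # Stand-in for an older forward() that only knows norm_by_times.
    def forward(self, logits, labels, hlens, ylens, norm_by_times=False):
        return 0.0

wanted = {
    "norm_by_times": True,
    "norm_by_batchsize": False,
    "norm_by_total_logits_len": False,
}
loss = OldStyleCTCLoss()
try:
    params = inspect.signature(loss.forward).parameters
except ValueError:
    # Built-in callables may not be introspectable; fall back to no kwargs.
    params = {}
supported = {k: v for k, v in wanted.items() if k in params}
dropped = {k: v for k, v in wanted.items() if k not in params}
loss_fn = partial(loss.forward, **supported)
# supported == {"norm_by_times": True}; the other two flags are dropped.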

@@ -69,7 +69,8 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
"""
return ~make_pad_mask(lengths)
#return ~make_pad_mask(lengths)
return make_pad_mask(lengths).logical_not()
def subsequent_mask(size: int) -> paddle.Tensor:
@@ -91,7 +92,12 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]]
"""
ret = paddle.ones([size, size], dtype=paddle.bool)
return paddle.tril(ret)
#TODO(Hui Zhang): tril not support bool
#return paddle.tril(ret)
ret = ret.astype(paddle.float)
ret = paddle.tril(ret)
ret = ret.astype(paddle.bool)
return ret
def subsequent_chunk_mask(
@@ -180,13 +186,15 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
num_left_chunks) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
# chunk_masks = masks & chunk_masks # (B, L, L)
chunk_masks = masks.logical_and(chunk_masks) # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
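
In mask.py the ~ and & operators get the same logical_not/logical_and treatment, and subsequent_mask now round-trips through float because tril rejects bool tensors on the targeted Paddle version. The tril workaround on its own:

import paddle

size = 3
ret = paddle.ones([size, size], dtype=paddle.bool)
# tril does not accept bool here, so cast to float, take the lower triangle,
# and cast back to bool.
ret = paddle.tril(ret.astype(paddle.float32)).astype(paddle.bool)
# [[True, False, False],
#  [True, True,  False],
#  [True, True,  True ]]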

@@ -183,7 +183,13 @@ def th_accuracy(pad_outputs: paddle.Tensor,
pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1],
pad_outputs.shape[1]).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
numerator = paddle.sum(numerator.type_as(pad_targets))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)
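
th_accuracy now compares the masked selections first and casts the boolean results before both reductions. A toy example with assumed labels (astype stands in for the type_as helper used in the diff):

import paddle

ignore_label = -1
pad_targets = paddle.to_tensor([[1, 2, -1], [3, -1, -1]])
pad_pred = paddle.to_tensor([[1, 0, 5], [3, 9, 9]])
mask = pad_targets != ignore_label
correct = pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
# Cast the bool tensors before summing, as in the hunk above.
numerator = paddle.sum(correct.astype(pad_targets.dtype))
denominator = paddle.sum(mask.astype(pad_targets.dtype))
accuracy = float(numerator) / float(denominator)   # 2 / 3 here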

@@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase):
def test_make_non_pad_mask(self):
res = make_non_pad_mask(self.lengths)
res2 = ~make_pad_mask(self.lengths)
res2 = make_pad_mask(self.lengths).logical_not()
self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist())
def test_make_pad_mask(self):
res = make_pad_mask(self.lengths)
res1 = ~make_non_pad_mask(self.lengths)
res1 = make_non_pad_mask(self.lengths).logical_not()
self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist())
self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())
