|
|
|
@ -159,7 +159,7 @@ class BaseEncoder(nn.Layer):
|
|
|
|
|
if self.global_cmvn is not None:
|
|
|
|
|
xs = self.global_cmvn(xs)
|
|
|
|
|
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
|
|
|
|
|
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
|
|
|
|
|
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
|
|
|
|
|
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
|
|
|
|
|
masks = masks.astype(paddle.bool)
|
|
|
|
|
mask_pad = ~masks
|
|
|
|
|