pull/578/head
Hui Zhang 5 years ago
parent 5e7e582d7c
commit 220c9443a9

@ -168,6 +168,8 @@ if not hasattr(paddle.Tensor, 'new_full'):
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
if convert_dtype_to_string(xs.dtype) == paddle.bool:
xs = xs.astype(paddle.int)
return xs.equal(
paddle.to_tensor(
ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
@ -262,7 +264,7 @@ def masked_fill_(xs: paddle.Tensor,
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret, output=xs)
paddle.assign(ret.detach(), output=xs)
if not hasattr(paddle.Tensor, 'masked_fill_'):
@ -273,7 +275,7 @@ if not hasattr(paddle.Tensor, 'masked_fill_'):
def fill_(xs: paddle.Tensor, value: Union[float, int]):
val = paddle.full_like(xs, value)
paddle.assign(val, output=xs)
paddle.assign(val.detach(), output=xs)
if not hasattr(paddle.Tensor, 'fill_'):

@ -162,8 +162,8 @@ class DeepSpeech2Model(nn.Layer):
assert (self.encoder.output_size == rnn_size * 2)
self.decoder = CTCDecoder(
enc_n_units=self.encoder.output_size,
odim=dict_size + 1, # <blank> is append after vocab
enc_n_units=self.encoder.output_size,
blank_id=dict_size, # last token is <blank>
dropout_rate=0.0,
reduction=True)

@ -112,7 +112,10 @@ class U2Model(nn.Module):
text.shape, text_lengths.shape)
# 1. Encoder
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. Attention-decoder branch
loss_att = None

@ -139,7 +139,7 @@ class ConvolutionModule(nn.Layer):
# It's better we just return None if no cache is requried,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
new_cache = paddle.zeros([1], dtype=x.dtype)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)

@ -34,16 +34,16 @@ __all__ = ['CTCDecoder']
class CTCDecoder(nn.Layer):
def __init__(self,
enc_n_units,
odim,
enc_n_units,
blank_id=0,
dropout_rate: float=0.0,
reduction: bool=True):
"""CTC decoder
Args:
odim ([int]): text vocabulary size
enc_n_units ([int]): encoder output dimention
vocab_size ([int]): text vocabulary size
dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): reduce the CTC loss into a scalar
"""

@ -26,7 +26,7 @@ from deepspeech.modules.decoder_layer import DecoderLayer
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import make_non_pad_mask
logger = logging.getLogger(__name__)
@ -124,7 +124,9 @@ class TransformerDecoder(nn.Module):
# m: (1, L, L)
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
# TODO(Hui Zhang): not support & for tensor
#tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt)
for layer in self.decoders:
@ -135,7 +137,9 @@ class TransformerDecoder(nn.Module):
if self.use_output_layer:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
#TODO(Hui Zhang): reduce_sum not support bool type
#olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens
def forward_one_step(

@ -155,12 +155,15 @@ class BaseEncoder(nn.Layer):
encoder output tensor, lens and mask
"""
masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks = masks.astype(paddle.bool)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,

@ -117,13 +117,12 @@ class LabelSmoothingLoss(nn.Layer):
B, T, D = paddle.shape(x)
assert D == self.size
x = x.reshape((-1, self.size))
target = target.reshape(-1)
target = target.reshape([-1])
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
ignore = ignore.cast(target.dtype)
#target = target * (1 - ignore) # avoid -1 index
target = target.masked_fill(ignore, 0) # avoid -1 index
@ -131,7 +130,9 @@ class LabelSmoothingLoss(nn.Layer):
kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
total = len(target) - int(ignore.sum())
#TODO(Hui Zhang): sum not support bool type
#total = len(target) - int(ignore.sum())
total = len(target) - int(ignore.type_as(target).sum())
denom = total if self.normalize_length else B
#numer = (kl * (1 - ignore)).sum()
numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()

@ -97,6 +97,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
"""
#TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~
return make_pad_mask(lengths).logical_not()
@ -119,7 +120,12 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]]
"""
ret = paddle.ones([size, size], dtype=paddle.bool)
return paddle.tril(ret)
#TODO(Hui Zhang): tril not support bool
#return paddle.tril(ret)
ret = ret.astype(paddle.float)
ret = paddle.tril(ret)
ret = ret.astype(paddle.bool)
return ret
def subsequent_chunk_mask(

@ -115,14 +115,28 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
_sos = paddle.to_tensor(
[sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
_eos = paddle.to_tensor(
[eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B = ys_pad.size(0)
_sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
_eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
ys_in = paddle.cat([_sos, ys_pad], dim=1)
mask_pad = (ys_in == ignore_id)
ys_in = ys_in.masked_fill(mask_pad, eos)
ys_out = paddle.cat([ys_pad, _eos], dim=1)
ys_out = ys_out.masked_fill(mask_pad, eos)
mask_eos = (ys_in == ignore_id)
ys_out = ys_out.masked_fill(mask_eos, eos)
ys_out = ys_out.masked_fill(mask_pad, ignore_id)
return ys_in, ys_out
def th_accuracy(pad_outputs: paddle.Tensor,
@ -139,7 +153,13 @@ def th_accuracy(pad_outputs: paddle.Tensor,
pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator = (
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
numerator = paddle.sum(numerator.type_as(pad_targets))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator = paddle.sum(mask.type_as(pad_targets))
return float(numerator) / float(denominator)

@ -86,8 +86,11 @@ class TestU2Model(unittest.TestCase):
cfg.freeze()
model = U2TransformerModel(cfg)
summary(model, None)
output = model(self.audio, self.audio_len, self.text, self.text_len)
print(output)
total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,
self.text, self.text_len)
self.assertEqual(total_loss.numel(), 1)
self.assertEqual(attention_loss.numel(), 1)
self.assertEqual(ctc_loss.numel(), 1)
def test_conformer(self):
conf_str = """
@ -135,8 +138,11 @@ class TestU2Model(unittest.TestCase):
cfg.freeze()
model = U2ConformerModel(cfg)
summary(model, None)
output = model(self.audio, self.audio_len, self.text, self.text_len)
print(output)
total_loss, attention_loss, ctc_loss = model(self.audio, self.audio_len,
self.text, self.text_len)
self.assertEqual(total_loss.numel(), 1)
self.assertEqual(attention_loss.numel(), 1)
self.assertEqual(ctc_loss.numel(), 1)
if __name__ == '__main__':

Loading…
Cancel
Save