From 5e7e582d7c8e5c968a460acf4d7527c525348130 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 7 Apr 2021 07:38:49 +0000
Subject: [PATCH] fix bugs

---
 deepspeech/__init__.py              | 67 ++++++++++++++++++++++++++---
 deepspeech/models/u2.py             | 10 ++---
 deepspeech/modules/attention.py     |  4 +-
 deepspeech/modules/embedding.py     | 17 +++----
 deepspeech/modules/encoder.py       | 14 +++---
 deepspeech/modules/encoder_layer.py |  4 +-
 deepspeech/utils/layer_tools.py     | 13 ++++--
 deepspeech/utils/tensor_utils.py    |  2 +-
 tests/u2_model_test.py              | 26 ++++++++---
 9 files changed, 116 insertions(+), 41 deletions(-)

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index ab5f0e137..b9cc2ca27 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -22,7 +22,8 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
-
+#TODO(Hui Zhang): remove fluid import
+from paddle.fluid import core
 logger = logging.getLogger(__name__)
 
 ########### hack logging #############
@@ -44,10 +45,51 @@ paddle.int = 'int32'
 paddle.int64 = 'int64'
 paddle.long = 'int64'
 paddle.uint8 = 'uint8'
+paddle.uint16 = 'uint16'
 paddle.complex64 = 'complex64'
 paddle.complex128 = 'complex128'
 paddle.cdouble = 'complex128'
+
+
+def convert_dtype_to_string(tensor_dtype):
+    """
+    Convert a Paddle tensor dtype (core.VarDesc.VarType) to its string alias.
+    Args:
+        tensor_dtype(core.VarDesc.VarType): the data type of a paddle.Tensor.
+    Returns:
+        str: the registered dtype string, e.g. paddle.float32 == 'float32'.
+    """
+    dtype = tensor_dtype
+    if dtype == core.VarDesc.VarType.FP32:
+        return paddle.float32
+    elif dtype == core.VarDesc.VarType.FP64:
+        return paddle.float64
+    elif dtype == core.VarDesc.VarType.FP16:
+        return paddle.float16
+    elif dtype == core.VarDesc.VarType.INT32:
+        return paddle.int32
+    elif dtype == core.VarDesc.VarType.INT16:
+        return paddle.int16
+    elif dtype == core.VarDesc.VarType.INT64:
+        return paddle.int64
+    elif dtype == core.VarDesc.VarType.BOOL:
+        return paddle.bool
+    elif dtype == core.VarDesc.VarType.BF16:
+        # since there is still no support for bfloat16 in NumPy,
+        # uint16 is used for casting bfloat16
+        return paddle.uint16
+    elif dtype == core.VarDesc.VarType.UINT8:
+        return paddle.uint8
+    elif dtype == core.VarDesc.VarType.INT8:
+        return paddle.int8
+    elif dtype == core.VarDesc.VarType.COMPLEX64:
+        return paddle.complex64
+    elif dtype == core.VarDesc.VarType.COMPLEX128:
+        return paddle.complex128
+    else:
+        raise ValueError("Not supported tensor dtype %s" % dtype)
+
+
 if not hasattr(paddle, 'softmax'):
     logger.warn("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
@@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):
 
 
 def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
-    return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
+    return xs.equal(
+        paddle.to_tensor(
+            ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
 
 
 if not hasattr(paddle.Tensor, 'eq'):
@@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
     paddle.Tensor.view_as = view_as
 
 
+def is_broadcastable(shp1, shp2):
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape)
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
@@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
 def masked_fill_(xs: paddle.Tensor,
                  mask: paddle.Tensor,
                  value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape)
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     ret = paddle.where(mask, trues, xs)
     paddle.assign(ret, output=xs)
@@ -414,4 +471,4 @@ if not hasattr(paddle.nn, 'ConstantPad2d'):
 
 if not hasattr(paddle.jit, 'export'):
     logger.warn("register user export to paddle.jit, remove this when fixed!")
-    setattr(paddle.jit, 'export', paddle.jit.to_static)
\ No newline at end of file
+    setattr(paddle.jit, 'export', paddle.jit.to_static)
diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py
index 9ecbc0177..d76dc76a6 100644
--- a/deepspeech/models/u2.py
+++ b/deepspeech/models/u2.py
@@ -39,7 +39,7 @@ from deepspeech.modules.encoder import ConformerEncoder
 from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.modules.decoder import TransformerDecoder
-from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
+from deepspeech.modules.loss import LabelSmoothingLoss
 
 from deepspeech.frontend.utility import load_cmvn
 
@@ -633,7 +633,7 @@ class U2Model(nn.Module):
 
 
 class U2TransformerModel(U2Model):
-    def __init__(configs: dict):
+    def __init__(self, configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -655,7 +655,7 @@ class U2TransformerModel(U2Model):
             **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())
 
-        self.__init__(
+        super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
             decoder=decoder,
@@ -664,7 +664,7 @@ class U2TransformerModel(U2Model):
 
 
 class U2ConformerModel(U2Model):
-    def __init__(configs: dict):
+    def __init__(self, configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -686,7 +686,7 @@ class U2ConformerModel(U2Model):
             **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())
 
-        self.__init__(
+        super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
             decoder=decoder,
diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py
index f9a91b94e..e9336c033 100644
--- a/deepspeech/modules/attention.py
+++ b/deepspeech/modules/attention.py
@@ -217,11 +217,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
+        matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
 
         # compute matrix b and matrix d
         # (batch, head, time1, time2)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
+        matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
 
         # Remove rel_shift since it is useless in speech recognition,
         # and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd) diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index efefd75ac..4746e1d04 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -48,7 +48,7 @@ class PositionalEncoding(nn.Layer): self.max_len = max_len self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) self.dropout = nn.Dropout(p=dropout_rate) - self.pe = paddle.zeros(self.max_len, self.d_model) #[T,D] + self.pe = paddle.zeros([self.max_len, self.d_model]) #[T,D] position = paddle.arange( 0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1] @@ -70,11 +70,9 @@ class PositionalEncoding(nn.Layer): paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...) paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) """ - T = paddle.shape(x)[1] - assert offset + T < self.max_len - #assert offset + x.size(1) < self.max_len - #self.pe = self.pe.to(x.device) - #pos_emb = self.pe[:, offset:offset + x.size(1)] + T = x.shape[1] + assert offset + x.size(1) < self.max_len + #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + T] x = x * self.xscale + pos_emb return self.dropout(x), self.dropout(pos_emb) @@ -119,11 +117,8 @@ class RelPositionalEncoding(PositionalEncoding): paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`). """ - #T = paddle.shape()[1] - #assert offset + T < self.max_len assert offset + x.size(1) < self.max_len - #self.pe = self.pe.to(x.device) x = x * self.xscale - pos_emb = self.pe[:, offset:offset + x.size(1)] - #pos_emb = self.pe[:, offset:offset + T] + #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor + pos_emb = self.pe[:, offset:offset + x.shape[1]] return self.dropout(x), self.dropout(pos_emb) diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index 9a4017fec..73829b75a 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -23,7 +23,7 @@ from paddle.nn import initializer as I from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import RelPositionMultiHeadedAttention -from deepspeech.modules.convolution import ConvolutionModule +from deepspeech.modules.conformer_convolution import ConvolutionModule from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import RelPositionalEncoding from deepspeech.modules.encoder_layer import TransformerEncoderLayer @@ -33,7 +33,7 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4 from deepspeech.modules.subsampling import Conv2dSubsampling6 from deepspeech.modules.subsampling import Conv2dSubsampling8 from deepspeech.modules.subsampling import LinearNoSubsampling -from deepspeech.modules.mask import make_pad_mask +from deepspeech.modules.mask import make_non_pad_mask from deepspeech.modules.mask import add_optional_chunk_mask from deepspeech.modules.activation import get_activation @@ -155,10 +155,12 @@ class BaseEncoder(nn.Layer): encoder output tensor, lens and mask """ masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L) + #TODO(Hui Zhang): mask_pad = ~masks + mask_pad = masks.logical_not() if self.global_cmvn is not None: xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks, offset=0) - mask_pad = ~masks + #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor + xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0) 
chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, @@ -380,7 +382,7 @@ class ConformerEncoder(BaseEncoder): concat_after: bool=False, static_chunk_size: int=0, use_dynamic_chunk: bool=False, - global_cmvn: torch.nn.Module=None, + global_cmvn: nn.Layer=None, use_dynamic_left_chunk: bool=False, positionwise_conv_kernel_size: int=1, macaron_style: bool=True, @@ -431,7 +433,7 @@ class ConformerEncoder(BaseEncoder): self.encoders = nn.ModuleList([ ConformerEncoderLayer( size=output_size, - eself_attn=ncoder_selfattn_layer(*encoder_selfattn_layer_args), + self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args), feed_forward=positionwise_layer(*positionwise_layer_args), feed_forward_macaron=positionwise_layer( *positionwise_layer_args) if macaron_style else None, diff --git a/deepspeech/modules/encoder_layer.py b/deepspeech/modules/encoder_layer.py index 2828f0053..d00e9f0a0 100644 --- a/deepspeech/modules/encoder_layer.py +++ b/deepspeech/modules/encoder_layer.py @@ -127,7 +127,7 @@ class TransformerEncoderLayer(nn.Layer): if output_cache is not None: x = paddle.concat([output_cache, x], axis=1) - fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place) + fake_cnn_cache = paddle.zeros([1], dtype=x.dtype) return x, mask, fake_cnn_cache @@ -253,7 +253,7 @@ class ConformerEncoderLayer(nn.Layer): # convolution module # Fake new cnn cache here, and then change it in conv_module - new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place) + new_cnn_cache = paddle.zeros([1], dtype=x.dtype) if self.conv_module is not None: residual = x if self.normalize_before: diff --git a/deepspeech/utils/layer_tools.py b/deepspeech/utils/layer_tools.py index 20c8ccf60..e3350dced 100644 --- a/deepspeech/utils/layer_tools.py +++ b/deepspeech/utils/layer_tools.py @@ -24,13 +24,18 @@ __all__ = [ def summary(layer: nn.Layer, print_func=print): num_params = num_elements = 0 - print_func("layer summary:") + if print_func: + print_func(f"{layer.__class__.__name__} summary:") for name, param in layer.state_dict().items(): - print_func("{}|{}|{}".format(name, param.shape, np.prod(param.shape))) + if print_func: + print_func( + "{} | {} | {}".format(name, param.shape, np.prod(param.shape))) num_elements += np.prod(param.shape) num_params += 1 - print_func("layer has {} parameters, {} elements.".format(num_params, - num_elements)) + if print_func: + print_func( + f"{layer.__class__.__name__} has {num_params} parameters, {num_elements} elements." 
+ ) def gradient_norm(layer: nn.Layer): diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 9f67c1a61..ab6dbfbc5 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -122,7 +122,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] - return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) + return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) def th_accuracy(pad_outputs: paddle.Tensor, diff --git a/tests/u2_model_test.py b/tests/u2_model_test.py index c8c547dd5..54a4dee14 100644 --- a/tests/u2_model_test.py +++ b/tests/u2_model_test.py @@ -20,6 +20,7 @@ from yacs.config import CfgNode as CN from deepspeech.models.u2 import U2TransformerModel from deepspeech.models.u2 import U2ConformerModel +from deepspeech.utils.layer_tools import summary class TestU2Model(unittest.TestCase): @@ -27,8 +28,9 @@ class TestU2Model(unittest.TestCase): paddle.set_device('cpu') self.batch_size = 2 - self.feat_dim = 161 + self.feat_dim = 83 self.max_len = 64 + self.vocab_size = 4239 #(B, T, D) audio = np.random.randn(self.batch_size, self.max_len, self.feat_dim) @@ -77,8 +79,15 @@ class TestU2Model(unittest.TestCase): length_normalized_loss: false """ cfg = CN().load_cfg(conf_str) - print(cfg) - model = U2TransformerModel() + cfg.input_dim = self.feat_dim + cfg.output_dim = self.vocab_size + cfg.cmvn_file = None + cfg.cmvn_file_type = 'npz' + cfg.freeze() + model = U2TransformerModel(cfg) + summary(model, None) + output = model(self.audio, self.audio_len, self.text, self.text_len) + print(output) def test_conformer(self): conf_str = """ @@ -119,8 +128,15 @@ class TestU2Model(unittest.TestCase): length_normalized_loss: false """ cfg = CN().load_cfg(conf_str) - print(cfg) - model = U2ConformerModel() + cfg.input_dim = self.feat_dim + cfg.output_dim = self.vocab_size + cfg.cmvn_file = None + cfg.cmvn_file_type = 'npz' + cfg.freeze() + model = U2ConformerModel(cfg) + summary(model, None) + output = model(self.audio, self.audio_len, self.text, self.text_len) + print(output) if __name__ == '__main__':
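
Notes for reviewers (illustrative sketches, not part of the applied diff):

1) The eq() change suggests that paddle.to_tensor at this Paddle version does
   not accept a raw core.VarDesc.VarType as its dtype argument, so the new
   convert_dtype_to_string helper maps the VarType enum back to the string
   aliases registered at the top of deepspeech/__init__.py. A minimal sketch
   of the intended round trip, assuming the patched package is importable:

       import paddle
       from deepspeech import convert_dtype_to_string  # helper added above

       x = paddle.to_tensor([1.0, 2.0])  # x.dtype is core.VarDesc.VarType.FP32
       dtype_str = convert_dtype_to_string(x.dtype)       # -> 'float32'
       y = paddle.to_tensor(0.0, dtype=dtype_str, place=x.place)
       print(x.equal(y))                 # elementwise compare, as eq() does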
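2) The masked_fill rewrite relaxes the old exact-shape assertion to
   broadcast-compatibility, e.g. a padding mask of shape (B, 1, T) applied to
   attention scores of shape (B, T, T). A standalone sketch with illustrative
   shapes and fill value:

       import paddle

       xs = paddle.zeros([2, 4, 6])                 # e.g. attention scores
       mask = paddle.ones([2, 1, 6], dtype='bool')  # mask with a broadcast dim

       # the same steps the patched masked_fill performs
       bshape = paddle.broadcast_shape(xs.shape, mask.shape)   # [2, 4, 6]
       filled = paddle.where(mask.broadcast_to(bshape),
                             paddle.ones_like(xs) * -1e9, xs)
       print(filled.shape)                          # [2, 4, 6]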
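3) tensor_utils.add_sos_eos now pads through pad_sequence with an explicit
   padding_value. A worked example of the expected values (ids illustrative;
   assumes batch-first output, as consumed by the U2 tests):

       # ignore_id = -1, sos = 10, eos = 11
       # ys_pad = [[1, 2, 3], [4, 5, -1]]
       # strip ignore_id:  ys     = [1, 2, 3], [4, 5]
       # prepend sos:      ys_in  = [10, 1, 2, 3], [10, 4, 5]
       # append eos:       ys_out = [1, 2, 3, 11], [4, 5, 11]
       # padded returns:
       #   ys_in_pad  = [[10, 1, 2, 3], [10, 4, 5, 11]]    # padded with eos
       #   ys_out_pad = [[ 1, 2, 3, 11], [ 4, 5, 11, -1]]  # padded with ignore_id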