pull/578/head
Hui Zhang 5 years ago
parent 2fa6bbbed5
commit 5e7e582d7c

@ -22,7 +22,8 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
#TODO(Hui Zhang): remove fluid import
from paddle.fluid import core
logger = logging.getLogger(__name__)
########### hack logging #############
@ -44,10 +45,51 @@ paddle.int = 'int32'
paddle.int64 = 'int64'
paddle.long = 'int64'
paddle.uint8 = 'uint8'
paddle.uint16 = 'uint16'
paddle.complex64 = 'complex64'
paddle.complex128 = 'complex128'
paddle.cdouble = 'complex128'
def convert_dtype_to_string(tensor_dtype):
    """Map a Paddle ``VarType`` enum member to the matching ``paddle.*`` dtype alias.

    Args:
        tensor_dtype (core.VarDesc.VarType): dtype enum member, e.g. the
            value of ``paddle.Tensor.dtype``.

    Returns:
        The value of the corresponding ``paddle.<name>`` attribute. Several
        of these are rebound to plain strings in this module
        (e.g. ``paddle.int64 = 'int64'``).

    Raises:
        ValueError: if ``tensor_dtype`` is none of the handled members.
    """
    # (VarType member name, paddle attribute name), checked in this order.
    # Both names are resolved lazily with getattr so that a member missing
    # from the installed Paddle build only matters if the lookup reaches it.
    vartype_to_attr = (
        ("FP32", "float32"),
        ("FP64", "float64"),
        ("FP16", "float16"),
        ("INT32", "int32"),
        ("INT16", "int16"),
        ("INT64", "int64"),
        ("BOOL", "bool"),
        # since there is still no support for bfloat16 in NumPy,
        # uint16 is used for casting bfloat16
        ("BF16", "uint16"),
        ("UINT8", "uint8"),
        ("INT8", "int8"),
        ("COMPLEX64", "complex64"),
        ("COMPLEX128", "complex128"),
    )
    for vartype_name, attr_name in vartype_to_attr:
        if tensor_dtype == getattr(core.VarDesc.VarType, vartype_name):
            return getattr(paddle, attr_name)
    raise ValueError("Not supported tensor dtype %s" % tensor_dtype)
# Expose softmax at the top level of the paddle namespace when the installed
# Paddle build does not provide it there.
if not hasattr(paddle, 'softmax'):
    # Logger.warn is a deprecated alias of Logger.warning.
    logger.warning("register user softmax to paddle, remove this when fixed!")
    setattr(paddle, 'softmax', paddle.nn.functional.softmax)
@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
    """Compare ``xs`` with ``ys`` via ``xs.equal``.

    ``ys`` is first turned into a tensor placed on ``xs.place``. Its dtype is
    passed through ``convert_dtype_to_string`` because ``xs.dtype`` is a
    ``VarType`` enum member rather than one of the ``paddle.*`` dtype aliases.

    Args:
        xs: left-hand tensor.
        ys: tensor or scalar to compare against.

    Returns:
        The result of ``xs.equal(...)``.
    """
    ys_tensor = paddle.to_tensor(
        ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place)
    return xs.equal(ys_tensor)
if not hasattr(paddle.Tensor, 'eq'):
@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
paddle.Tensor.view_as = view_as
def is_broadcastable(shp1, shp2) -> bool:
    """Return True if two shapes are compatible under broadcasting rules.

    The shapes are aligned from their trailing dimension; each aligned pair
    must be equal or contain a 1. Leading dimensions present in only one of
    the shapes are ignored, since ``zip`` stops at the shorter shape.

    Args:
        shp1: sequence of dimension sizes.
        shp2: sequence of dimension sizes.

    Returns:
        bool: whether every aligned pair of dimensions is compatible.
    """
    return all(a == 1 or b == 1 or a == b
               for a, b in zip(shp1[::-1], shp2[::-1]))
def masked_fill(xs: paddle.Tensor,
                mask: paddle.Tensor,
                value: Union[float, int]):
    """Return ``xs`` with the positions selected by ``mask`` set to ``value``.

    ``mask`` only has to be broadcast-compatible with ``xs``; it is expanded
    to the common broadcast shape before selection. ``xs`` itself is not
    modified.

    Args:
        xs: input tensor.
        mask: condition tensor passed to ``paddle.where``.
        value: scalar written where ``mask`` selects.

    Returns:
        paddle.Tensor: result of ``paddle.where(mask, value-filled, xs)``.
    """
    # Only broadcast compatibility is required; an exact-shape check here
    # would make the broadcasting below pointless.
    assert is_broadcastable(xs.shape, mask.shape)
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)
    # Tensor shaped like xs and filled with `value`.
    trues = paddle.ones_like(xs) * value
    return paddle.where(mask, trues, xs)
@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert xs.shape == mask.shape
assert is_broadcastable(xs.shape, mask.shape) == True
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret, output=xs)

@ -39,7 +39,7 @@ from deepspeech.modules.encoder import ConformerEncoder
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
from deepspeech.modules.loss import LabelSmoothingLoss
from deepspeech.frontend.utility import load_cmvn
@ -633,7 +633,7 @@ class U2Model(nn.Module):
class U2TransformerModel(U2Model):
def __init__(configs: dict):
def __init__(self, configs: dict):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type'])
@ -655,7 +655,7 @@ class U2TransformerModel(U2Model):
**configs['decoder_conf'])
ctc = CTCDecoder(vocab_size, encoder.output_size())
self.__init__(
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
@ -664,7 +664,7 @@ class U2TransformerModel(U2Model):
class U2ConformerModel(U2Model):
def __init__(configs: dict):
def __init__(self, configs: dict):
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type'])
@ -686,7 +686,7 @@ class U2ConformerModel(U2Model):
**configs['decoder_conf'])
ctc = CTCDecoder(vocab_size, encoder.output_size())
self.__init__(
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,

@ -217,11 +217,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)

@ -48,7 +48,7 @@ class PositionalEncoding(nn.Layer):
self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate)
self.pe = paddle.zeros(self.max_len, self.d_model) #[T,D]
self.pe = paddle.zeros([self.max_len, self.d_model]) #[T,D]
position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
@ -70,11 +70,9 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
T = paddle.shape(x)[1]
assert offset + T < self.max_len
#assert offset + x.size(1) < self.max_len
#self.pe = self.pe.to(x.device)
#pos_emb = self.pe[:, offset:offset + x.size(1)]
T = x.shape[1]
assert offset + x.size(1) < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
@ -119,11 +117,8 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
#T = paddle.shape()[1]
#assert offset + T < self.max_len
assert offset + x.size(1) < self.max_len
#self.pe = self.pe.to(x.device)
x = x * self.xscale
pos_emb = self.pe[:, offset:offset + x.size(1)]
#pos_emb = self.pe[:, offset:offset + T]
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)

@ -23,7 +23,7 @@ from paddle.nn import initializer as I
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.convolution import ConvolutionModule
from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.encoder_layer import TransformerEncoderLayer
@ -33,7 +33,7 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
from deepspeech.modules.subsampling import Conv2dSubsampling6
from deepspeech.modules.subsampling import Conv2dSubsampling8
from deepspeech.modules.subsampling import LinearNoSubsampling
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.mask import add_optional_chunk_mask
from deepspeech.modules.activation import get_activation
@ -155,10 +155,12 @@ class BaseEncoder(nn.Layer):
encoder output tensor, lens and mask
"""
masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad = masks.logical_not()
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks, offset=0)
mask_pad = ~masks
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
chunk_masks = add_optional_chunk_mask(
xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
decoding_chunk_size, self.static_chunk_size,
@ -380,7 +382,7 @@ class ConformerEncoder(BaseEncoder):
concat_after: bool=False,
static_chunk_size: int=0,
use_dynamic_chunk: bool=False,
global_cmvn: torch.nn.Module=None,
global_cmvn: nn.Layer=None,
use_dynamic_left_chunk: bool=False,
positionwise_conv_kernel_size: int=1,
macaron_style: bool=True,
@ -431,7 +433,7 @@ class ConformerEncoder(BaseEncoder):
self.encoders = nn.ModuleList([
ConformerEncoderLayer(
size=output_size,
eself_attn=ncoder_selfattn_layer(*encoder_selfattn_layer_args),
self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
feed_forward=positionwise_layer(*positionwise_layer_args),
feed_forward_macaron=positionwise_layer(
*positionwise_layer_args) if macaron_style else None,

@ -127,7 +127,7 @@ class TransformerEncoderLayer(nn.Layer):
if output_cache is not None:
x = paddle.concat([output_cache, x], axis=1)
fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
return x, mask, fake_cnn_cache
@ -253,7 +253,7 @@ class ConformerEncoderLayer(nn.Layer):
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
new_cnn_cache = paddle.zeros([1], dtype=x.dtype)
if self.conv_module is not None:
residual = x
if self.normalize_before:

@ -24,13 +24,18 @@ __all__ = [
def summary(layer: nn.Layer, print_func=print):
    """Report the entries of ``layer.state_dict()`` and their total size.

    One line is emitted per state-dict entry (name, shape, element count),
    preceded by a header and followed by a totals line.

    Args:
        layer: object exposing ``state_dict()`` whose values have a ``shape``.
        print_func: callable taking one string; pass ``None`` (or any falsy
            value) to suppress all output.
    """
    layer_name = layer.__class__.__name__
    if print_func:
        print_func(f"{layer_name} summary:")

    num_params = num_elements = 0
    for name, param in layer.state_dict().items():
        # int() keeps the count an integer even for an empty shape, where
        # np.prod returns the float 1.0.
        size = int(np.prod(param.shape))
        if print_func:
            print_func("{} | {} | {}".format(name, param.shape, size))
        num_elements += size
        num_params += 1

    if print_func:
        print_func(
            f"{layer_name} has {num_params} parameters, {num_elements} elements."
        )
def gradient_norm(layer: nn.Layer):

@ -122,7 +122,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
def th_accuracy(pad_outputs: paddle.Tensor,

@ -20,6 +20,7 @@ from yacs.config import CfgNode as CN
from deepspeech.models.u2 import U2TransformerModel
from deepspeech.models.u2 import U2ConformerModel
from deepspeech.utils.layer_tools import summary
class TestU2Model(unittest.TestCase):
@ -27,8 +28,9 @@ class TestU2Model(unittest.TestCase):
paddle.set_device('cpu')
self.batch_size = 2
self.feat_dim = 161
self.feat_dim = 83
self.max_len = 64
self.vocab_size = 4239
#(B, T, D)
audio = np.random.randn(self.batch_size, self.max_len, self.feat_dim)
@ -77,8 +79,15 @@ class TestU2Model(unittest.TestCase):
length_normalized_loss: false
"""
cfg = CN().load_cfg(conf_str)
print(cfg)
model = U2TransformerModel()
cfg.input_dim = self.feat_dim
cfg.output_dim = self.vocab_size
cfg.cmvn_file = None
cfg.cmvn_file_type = 'npz'
cfg.freeze()
model = U2TransformerModel(cfg)
summary(model, None)
output = model(self.audio, self.audio_len, self.text, self.text_len)
print(output)
def test_conformer(self):
conf_str = """
@ -119,8 +128,15 @@ class TestU2Model(unittest.TestCase):
length_normalized_loss: false
"""
cfg = CN().load_cfg(conf_str)
print(cfg)
model = U2ConformerModel()
cfg.input_dim = self.feat_dim
cfg.output_dim = self.vocab_size
cfg.cmvn_file = None
cfg.cmvn_file_type = 'npz'
cfg.freeze()
model = U2ConformerModel(cfg)
summary(model, None)
output = model(self.audio, self.audio_len, self.text, self.text_len)
print(output)
if __name__ == '__main__':

Loading…
Cancel
Save