pull/578/head
commit 5e7e582d7c (parent 2fa6bbbed5)
Author: Hui Zhang, 5 years ago

@@ -22,7 +22,8 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
+#TODO(Hui Zhang): remove fluid import
+from paddle.fluid import core
 logger = logging.getLogger(__name__)
 ########### hcak logging #############
@@ -44,10 +45,51 @@ paddle.int = 'int32'
 paddle.int64 = 'int64'
 paddle.long = 'int64'
 paddle.uint8 = 'uint8'
+paddle.uint16 = 'uint16'
 paddle.complex64 = 'complex64'
 paddle.complex128 = 'complex128'
 paddle.cdouble = 'complex128'
+def convert_dtype_to_string(tensor_dtype):
+    """
+    Convert a Paddle core.VarDesc.VarType into the dtype string Paddle APIs accept.
+    Args:
+        tensor_dtype (core.VarDesc.VarType): the tensor's data type enum.
+    Returns:
+        str: the corresponding dtype string, e.g. 'float32'.
+    """
+    dtype = tensor_dtype
+    if dtype == core.VarDesc.VarType.FP32:
+        return paddle.float32
+    elif dtype == core.VarDesc.VarType.FP64:
+        return paddle.float64
+    elif dtype == core.VarDesc.VarType.FP16:
+        return paddle.float16
+    elif dtype == core.VarDesc.VarType.INT32:
+        return paddle.int32
+    elif dtype == core.VarDesc.VarType.INT16:
+        return paddle.int16
+    elif dtype == core.VarDesc.VarType.INT64:
+        return paddle.int64
+    elif dtype == core.VarDesc.VarType.BOOL:
+        return paddle.bool
+    elif dtype == core.VarDesc.VarType.BF16:
+        # since there is still no support for bfloat16 in NumPy,
+        # uint16 is used for casting bfloat16
+        return paddle.uint16
+    elif dtype == core.VarDesc.VarType.UINT8:
+        return paddle.uint8
+    elif dtype == core.VarDesc.VarType.INT8:
+        return paddle.int8
+    elif dtype == core.VarDesc.VarType.COMPLEX64:
+        return paddle.complex64
+    elif dtype == core.VarDesc.VarType.COMPLEX128:
+        return paddle.complex128
+    else:
+        raise ValueError("Not supported tensor dtype %s" % dtype)
 if not hasattr(paddle, 'softmax'):
     logger.warn("register user softmax to paddle, remove this when fixed!")
     setattr(paddle, 'softmax', paddle.nn.functional.softmax)
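
Note: the helper above exists because, in this Paddle version, Tensor.dtype yields a core.VarDesc.VarType enum while paddle.to_tensor wants a dtype string (the aliases patched earlier in this file are plain strings). A minimal usage sketch, assuming Paddle 2.x dygraph and that the patched module above is in scope:

    import paddle
    from paddle.fluid import core

    x = paddle.ones([2, 3], dtype='float32')
    # In this Paddle version, x.dtype is the enum core.VarDesc.VarType.FP32,
    # which to_tensor does not accept directly.
    assert x.dtype == core.VarDesc.VarType.FP32
    # The helper maps the enum back to the patched string alias, e.g. 'float32'.
    y = paddle.to_tensor(2.0, dtype=convert_dtype_to_string(x.dtype))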
@@ -126,7 +168,9 @@ if not hasattr(paddle.Tensor, 'new_full'):
 def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
-    return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
+    return xs.equal(
+        paddle.to_tensor(
+            ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
 if not hasattr(paddle.Tensor, 'eq'):
@@ -184,10 +228,21 @@ if not hasattr(paddle.Tensor, 'view_as'):
     paddle.Tensor.view_as = view_as
+def is_broadcastable(shp1, shp2):
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape) == True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
@@ -202,7 +257,9 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
 def masked_fill_(xs: paddle.Tensor,
                  mask: paddle.Tensor,
                  value: Union[float, int]):
-    assert xs.shape == mask.shape
+    assert is_broadcastable(xs.shape, mask.shape) == True
+    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    mask = mask.broadcast_to(bshape)
     trues = paddle.ones_like(xs) * value
     ret = paddle.where(mask, trues, xs)
     paddle.assign(ret, output=xs)
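
Note: the relaxed assert follows the NumPy broadcasting rule: align shapes from the trailing dimension, and every aligned pair must be equal or 1. A sketch of what the change enables (shapes invented for illustration), assuming the helpers above are in scope:

    import paddle

    assert is_broadcastable([2, 4, 6], [2, 1, 6])      # dim 1 broadcasts
    assert not is_broadcastable([2, 4, 6], [3, 4, 6])  # 2 vs 3: incompatible

    xs = paddle.zeros([2, 4, 6])
    mask = paddle.ones([2, 1, 6], dtype='bool')        # (B, 1, T)-style mask
    out = masked_fill(xs, mask, -1e9)                  # mask expands to [2, 4, 6]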

@@ -39,7 +39,7 @@ from deepspeech.modules.encoder import ConformerEncoder
 from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.modules.decoder import TransformerDecoder
-from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
+from deepspeech.modules.loss import LabelSmoothingLoss
 from deepspeech.frontend.utility import load_cmvn
@@ -633,7 +633,7 @@ class U2Model(nn.Module):
 class U2TransformerModel(U2Model):
-    def __init__(configs: dict):
+    def __init__(self, configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -655,7 +655,7 @@ class U2TransformerModel(U2Model):
             **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())
-        self.__init__(
+        super().__init__(
             vocab_size=vocab_size,
             encoder=encoder,
             decoder=decoder,
@@ -664,7 +664,7 @@ class U2TransformerModel(U2Model):
 class U2ConformerModel(U2Model):
-    def __init__(configs: dict):
+    def __init__(self, configs: dict):
         if configs['cmvn_file'] is not None:
             mean, istd = load_cmvn(configs['cmvn_file'],
                                    configs['cmvn_file_type'])
@@ -686,7 +686,7 @@ class U2ConformerModel(U2Model):
             **configs['decoder_conf'])
         ctc = CTCDecoder(vocab_size, encoder.output_size())
-        self.__init__(
+        super().__init__(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
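
Note: both fixes above address the same bug: calling self.__init__(...) from inside __init__ re-enters the subclass constructor (with the wrong signature) instead of running U2Model's. A minimal plain-Python illustration of the pattern:

    class Base:
        def __init__(self, vocab_size, encoder):
            self.vocab_size = vocab_size
            self.encoder = encoder

    class Derived(Base):
        def __init__(self, configs: dict):
            # self.__init__(...) here would recurse into Derived.__init__
            # with the wrong arguments; super() dispatches to Base as intended.
            super().__init__(configs['vocab_size'], configs['encoder'])

    assert Derived({'vocab_size': 4239, 'encoder': None}).vocab_size == 4239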

@@ -217,11 +217,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         # first compute matrix a and matrix c
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
         # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
+        matrix_ac = paddle.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
         # compute matrix b and matrix d
         # (batch, head, time1, time2)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
+        matrix_bd = paddle.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
         # Remove rel_shift since it is useless in speech recognition,
         # and it requires special attention for streaming.
         # matrix_bd = self.rel_shift(matrix_bd)
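
Note: besides the torch->paddle rename, the surrounding code is the Transformer-XL score decomposition (https://arxiv.org/abs/1901.02860, Section 3.3): score = (q+u)k^T + (q+v)p^T, i.e. matrix (a+c) plus matrix (b+d). A NumPy shape sketch, with all dimensions invented for illustration:

    import numpy as np

    B, H, T1, T2, D = 2, 4, 5, 7, 8
    q_with_bias_u = np.random.randn(B, H, T1, D)  # query + content bias u
    q_with_bias_v = np.random.randn(B, H, T1, D)  # query + position bias v
    k = np.random.randn(B, H, T2, D)              # keys
    p = np.random.randn(B, H, T2, D)              # projected relative pos. enc.

    matrix_ac = q_with_bias_u @ k.transpose(0, 1, 3, 2)  # (B, H, T1, T2)
    matrix_bd = q_with_bias_v @ p.transpose(0, 1, 3, 2)  # (B, H, T1, T2)
    scores = (matrix_ac + matrix_bd) / np.sqrt(D)
    assert scores.shape == (B, H, T1, T2)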

@@ -48,7 +48,7 @@ class PositionalEncoding(nn.Layer):
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.pe = paddle.zeros(self.max_len, self.d_model)  #[T,D]
+        self.pe = paddle.zeros([self.max_len, self.d_model])  #[T,D]
         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
@@ -70,11 +70,9 @@ class PositionalEncoding(nn.Layer):
             paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
             paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
         """
-        T = paddle.shape(x)[1]
-        assert offset + T < self.max_len
-        #assert offset + x.size(1) < self.max_len
-        #self.pe = self.pe.to(x.device)
-        #pos_emb = self.pe[:, offset:offset + x.size(1)]
+        T = x.shape[1]
+        assert offset + x.size(1) < self.max_len
+        #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + T]
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
@@ -119,11 +117,8 @@ class RelPositionalEncoding(PositionalEncoding):
             paddle.Tensor: Encoded tensor (batch, time, `*`).
             paddle.Tensor: Positional embedding tensor (1, time, `*`).
         """
-        #T = paddle.shape()[1]
-        #assert offset + T < self.max_len
         assert offset + x.size(1) < self.max_len
-        #self.pe = self.pe.to(x.device)
         x = x * self.xscale
-        pos_emb = self.pe[:, offset:offset + x.size(1)]
-        #pos_emb = self.pe[:, offset:offset + T]
+        #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
+        pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
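
Note: the TODOs record the same Paddle limitation: Tensor.__getitem__ did not accept Tensor slice bounds, so the code reads the Python int x.shape[1] instead of the Tensor paddle.shape(x)[1]. A sketch of the slice that works, assuming Paddle 2.x dygraph (sizes invented):

    import paddle

    pe = paddle.zeros([1, 5000, 256])   # (1, max_len, d_model) buffer
    x = paddle.zeros([2, 16, 256])      # (batch, time, d_model) input
    offset = 0
    T = x.shape[1]                      # a Python int in dygraph
    pos_emb = pe[:, offset:offset + T]  # int slice bounds are supported
    assert pos_emb.shape == [1, 16, 256]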

@@ -23,7 +23,7 @@ from paddle.nn import initializer as I
 from deepspeech.modules.attention import MultiHeadedAttention
 from deepspeech.modules.attention import RelPositionMultiHeadedAttention
-from deepspeech.modules.convolution import ConvolutionModule
+from deepspeech.modules.conformer_convolution import ConvolutionModule
 from deepspeech.modules.embedding import PositionalEncoding
 from deepspeech.modules.embedding import RelPositionalEncoding
 from deepspeech.modules.encoder_layer import TransformerEncoderLayer
@@ -33,7 +33,7 @@ from deepspeech.modules.subsampling import Conv2dSubsampling4
 from deepspeech.modules.subsampling import Conv2dSubsampling6
 from deepspeech.modules.subsampling import Conv2dSubsampling8
 from deepspeech.modules.subsampling import LinearNoSubsampling
-from deepspeech.modules.mask import make_pad_mask
+from deepspeech.modules.mask import make_non_pad_mask
 from deepspeech.modules.mask import add_optional_chunk_mask
 from deepspeech.modules.activation import get_activation
@@ -155,10 +155,12 @@ class BaseEncoder(nn.Layer):
             encoder output tensor, lens and mask
         """
         masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)
+        #TODO(Hui Zhang): mask_pad = ~masks
+        mask_pad = masks.logical_not()
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
-        xs, pos_emb, masks = self.embed(xs, masks, offset=0)
-        mask_pad = ~masks
+        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
+        xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
         chunk_masks = add_optional_chunk_mask(
             xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
             decoding_chunk_size, self.static_chunk_size,
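
Note: the added lines work around two more gaps in this Paddle version: `~masks` (bool invert) is spelled masks.logical_not(), and the subsampling layers could not stride-slice bool tensors, hence the masks.type_as(xs) cast before self.embed. A sketch of the mask shapes, with make_non_pad_mask approximated inline and type_as approximated by astype:

    import paddle

    xs_lens = paddle.to_tensor([4, 2])                   # valid lengths, (B,)
    steps = paddle.arange(4).unsqueeze(0)                # (1, L)
    masks = (steps < xs_lens.unsqueeze(1)).unsqueeze(1)  # (B, 1, L), bool
    mask_pad = masks.logical_not()                       # True at padding positions
    xs = paddle.zeros([2, 4, 80])
    masks_float = masks.astype(xs.dtype)                 # stand-in for type_as(xs)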
@@ -380,7 +382,7 @@ class ConformerEncoder(BaseEncoder):
                  concat_after: bool=False,
                  static_chunk_size: int=0,
                  use_dynamic_chunk: bool=False,
-                 global_cmvn: torch.nn.Module=None,
+                 global_cmvn: nn.Layer=None,
                  use_dynamic_left_chunk: bool=False,
                  positionwise_conv_kernel_size: int=1,
                  macaron_style: bool=True,
@@ -431,7 +433,7 @@ class ConformerEncoder(BaseEncoder):
         self.encoders = nn.ModuleList([
             ConformerEncoderLayer(
                 size=output_size,
-                eself_attn=ncoder_selfattn_layer(*encoder_selfattn_layer_args),
+                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                 feed_forward=positionwise_layer(*positionwise_layer_args),
                 feed_forward_macaron=positionwise_layer(
                     *positionwise_layer_args) if macaron_style else None,

@@ -127,7 +127,7 @@ class TransformerEncoderLayer(nn.Layer):
         if output_cache is not None:
             x = paddle.concat([output_cache, x], axis=1)
-        fake_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+        fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)
         return x, mask, fake_cnn_cache
@@ -253,7 +253,7 @@ class ConformerEncoderLayer(nn.Layer):
         # convolution module
         # Fake new cnn cache here, and then change it in conv_module
-        new_cnn_cache = paddle.to_tensor([0.0], dtype=x.dtype, place=x.place)
+        new_cnn_cache = paddle.zeros([1], dtype=x.dtype)
         if self.conv_module is not None:
             residual = x
             if self.normalize_before:
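
Note: a hedged guess at the motivation for paddle.zeros here: it creates the placeholder cache directly with the right dtype on the current device, rather than copying a Python list via to_tensor with an explicit place. Sketch:

    import paddle

    x = paddle.randn([2, 8, 256])
    fake_cnn_cache = paddle.zeros([1], dtype=x.dtype)  # shape-(1,) placeholder
    assert fake_cnn_cache.shape == [1]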

@@ -24,13 +24,18 @@ __all__ = [
 def summary(layer: nn.Layer, print_func=print):
     num_params = num_elements = 0
-    print_func("layer summary:")
+    if print_func:
+        print_func(f"{layer.__class__.__name__} summary:")
     for name, param in layer.state_dict().items():
-        print_func("{}|{}|{}".format(name, param.shape, np.prod(param.shape)))
+        if print_func:
+            print_func(
+                "{} | {} | {}".format(name, param.shape, np.prod(param.shape)))
         num_elements += np.prod(param.shape)
         num_params += 1
-    print_func("layer has {} parameters, {} elements.".format(num_params,
-                                                              num_elements))
+    if print_func:
+        print_func(
+            f"{layer.__class__.__name__} has {num_params} parameters, {num_elements} elements."
+        )
 def gradient_norm(layer: nn.Layer):
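
Note: with print_func made optional, summary either prints per-parameter shapes or, when passed None, just walks the state dict and counts silently (as the updated tests below use it). A usage sketch:

    from paddle import nn

    layer = nn.Linear(4, 3)
    summary(layer)        # prints '<ClassName> summary:' plus one line per parameter
    summary(layer, None)  # counts parameters/elements without printing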

@@ -122,7 +122,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
     ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
     ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
     ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
-    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
+    return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
 def th_accuracy(pad_outputs: paddle.Tensor,
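
Note: the switch from pad_list to pad_sequence keeps the semantics of add_sos_eos: ys_in gets a leading sos and is padded with eos; ys_out gets a trailing eos and is padded with ignore_id. A behavioral sketch with plain Python lists (values invented):

    sos, eos, ignore_id = 1, 2, -1
    ys = [[7, 8, 9], [5]]
    ys_in = [[sos] + y for y in ys]    # [[1, 7, 8, 9], [1, 5]]
    ys_out = [y + [eos] for y in ys]   # [[7, 8, 9, 2], [5, 2]]
    max_len = max(len(y) for y in ys_in)
    pad = lambda seqs, v: [s + [v] * (max_len - len(s)) for s in seqs]
    assert pad(ys_in, eos) == [[1, 7, 8, 9], [1, 5, 2, 2]]
    assert pad(ys_out, ignore_id) == [[7, 8, 9, 2], [5, 2, -1, -1]]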

@@ -20,6 +20,7 @@ from yacs.config import CfgNode as CN
 from deepspeech.models.u2 import U2TransformerModel
 from deepspeech.models.u2 import U2ConformerModel
+from deepspeech.utils.layer_tools import summary
 class TestU2Model(unittest.TestCase):
@@ -27,8 +28,9 @@ class TestU2Model(unittest.TestCase):
         paddle.set_device('cpu')
         self.batch_size = 2
-        self.feat_dim = 161
+        self.feat_dim = 83
         self.max_len = 64
+        self.vocab_size = 4239
         #(B, T, D)
         audio = np.random.randn(self.batch_size, self.max_len, self.feat_dim)
@@ -77,8 +79,15 @@ class TestU2Model(unittest.TestCase):
             length_normalized_loss: false
         """
         cfg = CN().load_cfg(conf_str)
-        print(cfg)
-        model = U2TransformerModel()
+        cfg.input_dim = self.feat_dim
+        cfg.output_dim = self.vocab_size
+        cfg.cmvn_file = None
+        cfg.cmvn_file_type = 'npz'
+        cfg.freeze()
+        model = U2TransformerModel(cfg)
+        summary(model, None)
+        output = model(self.audio, self.audio_len, self.text, self.text_len)
+        print(output)
     def test_conformer(self):
         conf_str = """
@@ -119,8 +128,15 @@ class TestU2Model(unittest.TestCase):
             length_normalized_loss: false
         """
         cfg = CN().load_cfg(conf_str)
-        print(cfg)
-        model = U2ConformerModel()
+        cfg.input_dim = self.feat_dim
+        cfg.output_dim = self.vocab_size
+        cfg.cmvn_file = None
+        cfg.cmvn_file_type = 'npz'
+        cfg.freeze()
+        model = U2ConformerModel(cfg)
+        summary(model, None)
+        output = model(self.audio, self.audio_len, self.text, self.text_len)
+        print(output)
 if __name__ == '__main__':
