From 7635f98bce89fc4955f84ba202742f94582c83e2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 18 Mar 2021 12:47:14 +0000 Subject: [PATCH] add attention, common utils, hack paddle --- deepspeech/__init__.py | 304 +++++++++++++++++++++++++++++++ deepspeech/modules/__init__.py | 246 +------------------------ deepspeech/modules/activation.py | 44 ++--- deepspeech/modules/attention.py | 85 +++++---- deepspeech/modules/embedding.py | 13 +- deepspeech/utils/common.py | 113 ++++++++++++ deepspeech/utils/layer_tools.py | 3 +- deepspeech/utils/metric.py | 43 +++++ tests/network_test.py | 26 +-- 9 files changed, 539 insertions(+), 338 deletions(-) create mode 100644 deepspeech/utils/common.py create mode 100644 deepspeech/utils/metric.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 185a92b8d..563746f41 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -11,3 +11,307 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import Union +from typing import Any + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +logger = logging.getLogger(__name__) + +########### hcak logging ############# +logger.warn = logging.warning + +########### hcak paddle ############# +paddle.bool = 'bool' +paddle.float16 = 'float16' +paddle.half = 'float16' +paddle.float32 = 'float32' +paddle.float = 'float32' +paddle.float64 = 'float64' +paddle.double = 'float64' +paddle.int8 = 'int8' +paddle.int16 = 'int16' +paddle.short = 'int16' +paddle.int32 = 'int32' +paddle.int = 'int32' +paddle.int64 = 'int64' +paddle.long = 'int64' +paddle.uint8 = 'uint8' +paddle.complex64 = 'complex64' +paddle.complex128 = 'complex128' +paddle.cdouble = 'complex128' + +if not hasattr(paddle, 'softmax'): + logger.warn("register user softmax to paddle, remove this when fixed!") + setattr(paddle, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle, 'sigmoid'): + logger.warn("register user sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle, 'relu'): + logger.warn("register user relu to paddle, remove this when fixed!") + setattr(paddle, 'relu', paddle.nn.functional.relu) + + +def cat(xs, dim=0): + return paddle.concat(xs, axis=dim) + + +if not hasattr(paddle, 'cat'): + logger.warn( + "override cat of paddle if exists or register, remove this when fixed!") + paddle.cat = cat + +########### hcak paddle.Tensor ############# +if not hasattr(paddle.Tensor, 'numel'): + logger.warn( + "override numel of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.numel = paddle.numel + + +def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: + return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place)) + + +if not hasattr(paddle.Tensor, 'eq'): + logger.warn( + "override eq of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.eq = eq + + +def contiguous(xs: paddle.Tensor) -> paddle.Tensor: + return xs + + +if not hasattr(paddle.Tensor, 'contiguous'): + logger.warn( + "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" 
+ ) + paddle.Tensor.contiguous = contiguous + + +def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + nargs = len(args) + assert (nargs <= 1) + s = paddle.shape(xs) + if nargs == 1: + return s[args[0]] + else: + return s + + +#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. +logger.warn( + "override size of paddle.Tensor " + "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" +) +paddle.Tensor.size = size + + +def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + return xs.reshape(args) + + +if not hasattr(paddle.Tensor, 'view'): + logger.warn("register user view to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view = view + + +def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: + return xs.reshape(ys.size()) + + +if not hasattr(paddle.Tensor, 'view_as'): + logger.warn( + "register user view_as to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view_as = view_as + + +def masked_fill(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]): + assert xs.shape == mask.shape + trues = paddle.ones_like(xs) * value + xs = paddle.where(mask, trues, xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill'): + logger.warn( + "register user masked_fill to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill = masked_fill + + +def masked_fill_(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]): + assert xs.shape == mask.shape + trues = paddle.ones_like(xs) * value + ret = paddle.where(mask, trues, xs) + paddle.assign(ret, output=xs) + + +if not hasattr(paddle.Tensor, 'masked_fill_'): + logger.warn( + "register user masked_fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill_ = masked_fill_ + + +def fill_(xs: paddle.Tensor, value: Union[float, int]): + val = paddle.full_like(xs, value) + paddle.assign(val, output=xs) + + +if not hasattr(paddle.Tensor, 'fill_'): + logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.fill_ = fill_ + + +def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: + return paddle.tile(xs, size) + + +if not hasattr(paddle.Tensor, 'repeat'): + logger.warn( + "register user repeat to paddle.Tensor, remove this when fixed!") + paddle.Tensor.repeat = repeat + +if not hasattr(paddle.Tensor, 'softmax'): + logger.warn( + "register user softmax to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle.Tensor, 'sigmoid'): + logger.warn( + "register user sigmoid to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle.Tensor, 'relu'): + logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) + +########### hcak paddle.nn.functional ############# + + +def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor: + """The gated linear unit (GLU) activation.""" + a, b = x.split(2, axis=dim) + act_b = F.sigmoid(b) + return a * act_b + + +if not hasattr(paddle.nn.functional, 'glu'): + logger.warn( + "register user glu to paddle.nn.functional, remove this when fixed!") + setattr(paddle.nn.functional, 'glu', glu) + +# def softplus(x): +# """Softplus function.""" +# if hasattr(paddle.nn.functional, 'softplus'): +# #return paddle.nn.functional.softplus(x.float()).type_as(x) +# return 
paddle.nn.functional.softplus(x) +# else: +# raise NotImplementedError + +# def gelu_accurate(x): +# """Gaussian Error Linear Units (GELU) activation.""" +# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py +# if not hasattr(gelu_accurate, "_a"): +# gelu_accurate._a = math.sqrt(2 / math.pi) +# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a * +# (x + 0.044715 * paddle.pow(x, 3)))) + +# def gelu(x): +# """Gaussian Error Linear Units (GELU) activation.""" +# if hasattr(nn.functional, 'gelu'): +# #return nn.functional.gelu(x.float()).type_as(x) +# return nn.functional.gelu(x) +# else: +# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0))) + + +# hack loss +def ctc_loss(logits, + labels, + input_lengths, + label_lengths, + blank=0, + reduction='mean', + norm_by_times=True): + #logger.info("my ctc loss with norm by times") + ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403 + loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times, + input_lengths, label_lengths) + + loss_out = paddle.fluid.layers.squeeze(loss_out, [-1]) + logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ") + assert reduction in ['mean', 'sum', 'none'] + if reduction == 'mean': + loss_out = paddle.mean(loss_out / label_lengths) + elif reduction == 'sum': + loss_out = paddle.sum(loss_out) + logger.info(f"ctc loss: {loss_out}") + return loss_out + + +logger.warn( + "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!" +) +F.ctc_loss = ctc_loss + + +########### hcak paddle.nn ############# +class GLU(nn.Layer): + """Gated Linear Units (GLU) Layer""" + + def __init__(self, dim: int=-1): + super().__init__() + self.dim = dim + + def forward(self, xs): + return glu(xs, dim=self.dim) + + +if not hasattr(paddle.nn, 'GLU'): + logger.warn("register user GLU to paddle.nn, remove this when fixed!") + setattr(paddle.nn, 'GLU', GLU) + + +# TODO(Hui Zhang): remove this Layer +class ConstantPad2d(nn.Layer): + """Pads the input tensor boundaries with a constant value. + For N-dimensional padding, use paddle.nn.functional.pad(). + """ + + def __init__(self, padding: Union[tuple, list, int], value: float): + """ + Args: + paddle ([tuple]): the size of the padding. + If is int, uses the same padding in all boundaries. + If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom) + value ([flaot]): pad value + """ + self.padding = padding if isinstance(padding, + [tuple, list]) else [padding] * 4 + self.value = value + + def forward(self, xs: paddle.Tensor) -> paddle.Tensor: + return nn.functional.pad( + xs, + self.padding, + mode='constant', + value=self.value, + data_format='NCHW') + + +if not hasattr(paddle.nn, 'ConstantPad2d'): + logger.warn( + "register user ConstantPad2d to paddle.nn, remove this when fixed!") + setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d) diff --git a/deepspeech/modules/__init__.py b/deepspeech/modules/__init__.py index 973bc0624..61d5aa213 100644 --- a/deepspeech/modules/__init__.py +++ b/deepspeech/modules/__init__.py @@ -10,248 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -from typing import Union -from typing import Any - -import paddle -from paddle import nn -from paddle.nn import functional as F -from paddle.nn import initializer as I - -logger = logging.getLogger(__name__) -logger.warn = logging.warning - -# TODO(Hui Zhang): remove this hack -paddle.bool = 'bool' -paddle.float16 = 'float16' -paddle.float32 = 'float32' -paddle.float64 = 'float64' -paddle.int8 = 'int8' -paddle.int16 = 'int16' -paddle.int32 = 'int32' -paddle.int64 = 'int64' -paddle.uint8 = 'uint8' -paddle.complex64 = 'complex64' -paddle.complex128 = 'complex128' - -if not hasattr(paddle.Tensor, 'cat'): - logger.warn( - "override cat of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.cat = paddle.Tensor.concat - - -def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: - return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place)) - - -if not hasattr(paddle.Tensor, 'eq'): - logger.warn( - "override eq of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.eq = eq - - -def contiguous(xs: paddle.Tensor) -> paddle.Tensor: - return xs - - -if not hasattr(paddle.Tensor, 'contiguous'): - logger.warn( - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" - ) - paddle.Tensor.contiguous = contiguous - - -def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: - nargs = len(args) - assert (nargs <= 1) - s = paddle.shape(xs) - if nargs == 1: - return s[args[0]] - else: - return s - - -#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. -logger.warn( - "override size of paddle.Tensor " - "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" 
-) -paddle.Tensor.size = size - - -def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: - return xs.reshape(args) - - -if not hasattr(paddle.Tensor, 'view'): - logger.warn("register user view to paddle.Tensor, remove this when fixed!") - paddle.Tensor.view = view - - -def masked_fill(xs: paddle.Tensor, - mask: paddle.Tensor, - value: Union[float, int]): - assert xs.shape == mask.shape - trues = paddle.ones_like(xs) * value - xs = paddle.where(mask, trues, xs) - return xs - - -if not hasattr(paddle.Tensor, 'masked_fill'): - logger.warn( - "register user masked_fill to paddle.Tensor, remove this when fixed!") - paddle.Tensor.masked_fill = masked_fill - - -def masked_fill_(xs: paddle.Tensor, - mask: paddle.Tensor, - value: Union[float, int]): - assert xs.shape == mask.shape - trues = paddle.ones_like(xs) * value - ret = paddle.where(mask, trues, xs) - paddle.assign(ret, output=xs) - - -if not hasattr(paddle.Tensor, 'masked_fill_'): - logger.warn( - "register user masked_fill_ to paddle.Tensor, remove this when fixed!") - paddle.Tensor.masked_fill_ = masked_fill_ - - -def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: - return paddle.tile(xs, size) - - -if not hasattr(paddle.Tensor, 'repeat'): - logger.warn( - "register user repeat to paddle.Tensor, remove this when fixed!") - paddle.Tensor.repeat = repeat - -# def softplus(x): -# """Softplus function.""" -# if hasattr(paddle.nn.functional, 'softplus'): -# #return paddle.nn.functional.softplus(x.float()).type_as(x) -# return paddle.nn.functional.softplus(x) -# else: -# raise NotImplementedError - -# def gelu_accurate(x): -# """Gaussian Error Linear Units (GELU) activation.""" -# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py -# if not hasattr(gelu_accurate, "_a"): -# gelu_accurate._a = math.sqrt(2 / math.pi) -# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a * -# (x + 0.044715 * paddle.pow(x, 3)))) - -# def gelu(x): -# """Gaussian Error Linear Units (GELU) activation.""" -# if hasattr(nn.functional, 'gelu'): -# #return nn.functional.gelu(x.float()).type_as(x) -# return nn.functional.gelu(x) -# else: -# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0))) - - -def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor: - """The gated linear unit (GLU) activation.""" - a, b = x.split(2, axis=dim) - act_b = F.sigmoid(b) - return a * act_b - - -if not hasattr(paddle.nn.functional, 'glu'): - logger.warn( - "register user glu to paddle.nn.functional, remove this when fixed!") - setattr(paddle.nn.functional, 'glu', glu) - - -# TODO(Hui Zhang): remove this activation -class GLU(nn.Layer): - """Gated Linear Units (GLU) Layer""" - - def __init__(self, dim: int=-1): - super().__init__() - self.dim = dim - - def forward(self, xs): - return glu(xs, dim=self.dim) - - -if not hasattr(paddle.nn, 'GLU'): - logger.warn("register user GLU to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'GLU', GLU) - - -# TODO(Hui Zhang): remove this Layer -class ConstantPad2d(nn.Layer): - """Pads the input tensor boundaries with a constant value. - For N-dimensional padding, use paddle.nn.functional.pad(). - """ - - def __init__(self, padding: Union[tuple, list, int], value: float): - """ - Args: - paddle ([tuple]): the size of the padding. - If is int, uses the same padding in all boundaries. 
- If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom) - value ([flaot]): pad value - """ - self.padding = padding if isinstance(padding, - [tuple, list]) else [padding] * 4 - self.value = value - - def forward(self, xs: paddle.Tensor) -> paddle.Tensor: - return nn.functional.pad( - xs, - self.padding, - mode='constant', - value=self.value, - data_format='NCHW') - - -if not hasattr(paddle.nn, 'ConstantPad2d'): - logger.warn( - "register user ConstantPad2d to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d) - -if not hasattr(paddle, 'softmax'): - logger.warn("register user softmax to paddle, remove this when fixed!") - setattr(paddle, 'softmax', paddle.nn.functional.softmax) - -if not hasattr(paddle, 'sigmoid'): - logger.warn("register user softmax to paddle, remove this when fixed!") - setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) - - -# hack loss -def ctc_loss(logits, - labels, - input_lengths, - label_lengths, - blank=0, - reduction='mean', - norm_by_times=True): - #logger.info("my ctc loss with norm by times") - ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403 - loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times, - input_lengths, label_lengths) - - loss_out = paddle.fluid.layers.squeeze(loss_out, [-1]) - logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ") - assert reduction in ['mean', 'sum', 'none'] - if reduction == 'mean': - loss_out = paddle.mean(loss_out / label_lengths) - elif reduction == 'sum': - loss_out = paddle.sum(loss_out) - logger.info(f"ctc loss: {loss_out}") - return loss_out - - -logger.warn( - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!" -) -F.ctc_loss = ctc_loss +# limitations under the License. \ No newline at end of file diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py index 827791f36..7769a7855 100644 --- a/deepspeech/modules/activation.py +++ b/deepspeech/modules/activation.py @@ -25,7 +25,7 @@ from paddle.nn import initializer as I logger = logging.getLogger(__name__) -__all__ = ["brelu", "LinearGLUBlock", "ConstantPad2d", "ConvGLUBlock"] +__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"] def brelu(x, t_min=0.0, t_max=24.0, name=None): @@ -50,33 +50,6 @@ class LinearGLUBlock(nn.Layer): return glu(self.fc(xs), dim=-1) -# TODO(Hui Zhang): remove this Layer -class ConstantPad2d(nn.Layer): - """Pads the input tensor boundaries with a constant value. - For N-dimensional padding, use paddle.nn.functional.pad(). - """ - - def __init__(self, padding: Union[tuple, list, int], value: float): - """ - Args: - paddle ([tuple]): the size of the padding. - If is int, uses the same padding in all boundaries. 
- If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom) - value ([flaot]): pad value - """ - self.padding = padding if isinstance(padding, - [tuple, list]) else [padding] * 4 - self.value = value - - def forward(self, xs: paddle.Tensor) -> paddle.Tensor: - return nn.functional.pad( - xs, - self.padding, - mode='constant', - value=self.value, - data_format='NCHW') - - class ConvGLUBlock(nn.Layer): def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0, dropout=0.): @@ -159,3 +132,18 @@ class ConvGLUBlock(nn.Layer): xs = self.layers(xs) # `[B, out_ch * 2, T ,1]` xs = xs + residual return xs + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + activation_funcs = { + "hardtanh": paddle.nn.Hardtanh, + "tanh": paddle.nn.Tanh, + "relu": paddle.nn.ReLU, + "selu": paddle.nn.SELU, + "swish": paddle.nn.Swish, + "gelu": paddle.nn.GELU + } + + return activation_funcs[act]() diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py index d75a7f841..f9a91b94e 100644 --- a/deepspeech/modules/attention.py +++ b/deepspeech/modules/attention.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Multi-Head Attention layer definition.""" import math import logging @@ -26,6 +25,10 @@ logger = logging.getLogger(__name__) __all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] +# Relative Positional Encodings +# https://www.jianshu.com/p/c0608efcc26f +# https://zhuanlan.zhihu.com/p/344604604 + class MultiHeadedAttention(nn.Layer): """Multi-Head Attention layer.""" @@ -89,8 +92,8 @@ class MultiHeadedAttention(nn.Layer): mask (paddle.Tensor): Mask, size (#batch, 1, time2) or (#batch, time1, time2). Returns: - paddle.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). + paddle.Tensor: Transformed value weighted + by the attention score, (#batch, time1, d_model). """ n_batch = value.size(0) if mask is not None: @@ -126,8 +129,8 @@ class MultiHeadedAttention(nn.Layer): torch.Tensor: Output tensor (#batch, time1, d_model). 
""" q, k, v = self.forward_qkv(query, key, value) - scores = paddle.matmul(q, k.transpose( - [0, 1, 3, 2])) / math.sqrt(self.d_k) + scores = paddle.matmul(q, + k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask) @@ -147,76 +150,78 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False) # these two learnable bias are used in matrix c and matrix d # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) + #self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + #self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + #torch.nn.init.xavier_uniform_(self.pos_bias_u) + #torch.nn.init.xavier_uniform_(self.pos_bias_v) + pos_bias_u = self.create_parameter( + [self.h, self.d_k], default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_u', pos_bias_u) + pos_bias_v = self.create_parameter( + (self.h, self.d_k), default_initializer=I.XavierUniform()) + self.add_parameter('pos_bias_v', pos_bias_v) def rel_shift(self, x, zero_triu: bool=False): """Compute relative positinal encoding. Args: - x (torch.Tensor): Input tensor (batch, time, size). + x (paddle.Tensor): Input tensor (batch, head, time1, time1). zero_triu (bool): If true, return the lower triangular part of the matrix. Returns: - torch.Tensor: Output tensor. + paddle.Tensor: Output tensor. (batch, head, time1, time1) """ + zero_pad = paddle.zeros( + (x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) + x_padded = paddle.cat([zero_pad, x], dim=-1) - zero_pad = torch.zeros( - (x.size()[0], x.size()[1], x.size()[2], 1), - device=x.device, - dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(x.size()[0], - x.size()[1], x.size(3) + 1, x.size(2)) - x = x_padded[:, :, 1:].view_as(x) + x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] if zero_triu: - ones = torch.ones((x.size(2), x.size(3))) - x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + ones = paddle.ones((x.size(2), x.size(3))) + x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] return x def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - pos_emb: torch.Tensor, - mask: Optional[torch.Tensor]): + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + pos_emb: paddle.Tensor, + mask: Optional[paddle.Tensor]): """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - pos_emb (torch.Tensor): Positional embedding tensor - (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + query (paddle.Tensor): Query tensor (#batch, time1, size). + key (paddle.Tensor): Key tensor (#batch, time2, size). + value (paddle.Tensor): Value tensor (#batch, time2, size). + pos_emb (paddle.Tensor): Positional embedding tensor + (#batch, time1, size). + mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or (#batch, time1, time2). Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). 
+ paddle.Tensor: Output tensor (#batch, time1, d_model). """ q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) + q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) n_batch_pos = pos_emb.size(0) p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, time1, d_k) + p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3]) # compute attention score # first compute matrix a and matrix c # as described in https://arxiv.org/abs/1901.02860 Section 3.3 # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2])) # compute matrix b and matrix d # (batch, head, time1, time2) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2])) # Remove rel_shift since it is useless in speech recognition, # and it requires special attention for streaming. # matrix_bd = self.rel_shift(matrix_bd) diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index 114bcd25f..be2103292 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -27,9 +27,6 @@ logger = logging.getLogger(__name__) __all__ = ["PositionalEncoding", "RelPositionalEncoding"] -# TODO(Hui Zhang): remove this hack -paddle.float32 = 'float32' - class PositionalEncoding(nn.Layer): def __init__(self, @@ -122,11 +119,11 @@ class RelPositionalEncoding(PositionalEncoding): paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`). """ - T = paddle.shape()[1] - assert offset + T < self.max_len - #assert offset + x.size(1) < self.max_len + #T = paddle.shape()[1] + #assert offset + T < self.max_len + assert offset + x.size(1) < self.max_len #self.pe = self.pe.to(x.device) x = x * self.xscale - #pos_emb = self.pe[:, offset:offset + x.size(1)] - pos_emb = self.pe[:, offset:offset + T] + pos_emb = self.pe[:, offset:offset + x.size(1)] + #pos_emb = self.pe[:, offset:offset + T] return self.dropout(x), self.dropout(pos_emb) diff --git a/deepspeech/utils/common.py b/deepspeech/utils/common.py new file mode 100644 index 000000000..801b32e95 --- /dev/null +++ b/deepspeech/utils/common.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unility functions for Transformer.""" +import math +import logging +from typing import Tuple, List + +import paddle + +logger = logging.getLogger(__name__) + +__all__ = ["pad_list", "add_sos_eos", "remove_duplicates_and_blank", "log_add"] + +IGNORE_ID = -1 + + +def pad_list(xs: List[paddle.Tensor], pad_value: int): + """Perform padding for the list of tensors. + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + Examples: + >>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + """ + n_batch = len(xs) + max_len = max([x.size(0) for x in xs]) + pad = paddle.zeros(n_batch, max_len, dtype=xs[0].dtype) + pad = pad.fill_(pad_value) + for i in range(n_batch): + pad[i, :xs[i].size(0)] = xs[i] + + return pad + + +def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, + ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Add and labels. + Args: + ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax) + sos (int): index of + eos (int): index of + ignore_id (int): index of padding + Returns: + ys_in (paddle.Tensor) : (B, Lmax + 1) + ys_out (paddle.Tensor) : (B, Lmax + 1) + Examples: + >>> sos_id = 10 + >>> eos_id = 11 + >>> ignore_id = -1 + >>> ys_pad + tensor([[ 1, 2, 3, 4, 5], + [ 4, 5, 6, -1, -1], + [ 7, 8, 9, -1, -1]], dtype=paddle.int32) + >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id) + >>> ys_in + tensor([[10, 1, 2, 3, 4, 5], + [10, 4, 5, 6, 11, 11], + [10, 7, 8, 9, 11, 11]]) + >>> ys_out + tensor([[ 1, 2, 3, 4, 5, 11], + [ 4, 5, 6, 11, -1, -1], + [ 7, 8, 9, 11, -1, -1]]) + """ + _sos = paddle.to_tensor( + [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) + _eos = paddle.to_tensor( + [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) + ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys + ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] + ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] + return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) + + +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def log_add(args: List[int]) -> float: + """ + Stable log add + """ + if all(a == -float('inf') for a in args): + return -float('inf') + a_max = max(args) + lsp = math.log(sum(math.exp(a - a_max) for a in args)) + return a_max + lsp diff --git a/deepspeech/utils/layer_tools.py b/deepspeech/utils/layer_tools.py index 46a354761..20c8ccf60 100644 --- a/deepspeech/utils/layer_tools.py +++ b/deepspeech/utils/layer_tools.py @@ -13,6 +13,7 @@ # limitations under the License. 
import numpy as np + from paddle import nn __all__ = [ @@ -36,7 +37,7 @@ def gradient_norm(layer: nn.Layer): grad_norm_dict = {} for name, param in layer.state_dict().items(): if param.trainable: - grad = param.gradient() + grad = param.gradient() # return numpy.ndarray grad_norm_dict[name] = np.linalg.norm(grad) / grad.size return grad_norm_dict diff --git a/deepspeech/utils/metric.py b/deepspeech/utils/metric.py new file mode 100644 index 000000000..e53b24056 --- /dev/null +++ b/deepspeech/utils/metric.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import logging +from typing import Tuple, List + +import paddle + +logger = logging.getLogger(__name__) + +__all__ = ["th_accuracy"] + + +def th_accuracy(pad_outputs: paddle.Tensor, + pad_targets: paddle.Tensor, + ignore_label: int) -> float: + """Calculate accuracy. + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + Returns: + float: Accuracy value (0.0 - 1.0). + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = paddle.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = paddle.sum(mask) + return float(numerator) / float(denominator) diff --git a/tests/network_test.py b/tests/network_test.py index 7e8d62c2b..31c5efc79 100644 --- a/tests/network_test.py +++ b/tests/network_test.py @@ -15,10 +15,9 @@ import paddle import numpy as np -from deepspeech.models.network import DeepSpeech2 +from deepspeech.models.deepspeech2 import DeepSpeech2Model if __name__ == '__main__': - batch_size = 2 feat_dim = 161 max_len = 100 @@ -28,15 +27,10 @@ if __name__ == '__main__': text = np.array([[1, 2], [1, 2]], dtype='int32') text_len = np.array([2] * batch_size, dtype='int32') - place = paddle.CUDAPlace(0) - audio = paddle.to_tensor( - audio, dtype='float32', place=place, stop_gradient=True) - audio_len = paddle.to_tensor( - audio_len, dtype='int64', place=place, stop_gradient=True) - text = paddle.to_tensor( - text, dtype='int32', place=place, stop_gradient=True) - text_len = paddle.to_tensor( - text_len, dtype='int64', place=place, stop_gradient=True) + audio = paddle.to_tensor(audio, dtype='float32') + audio_len = paddle.to_tensor(audio_len, dtype='int64') + text = paddle.to_tensor(text, dtype='int32') + text_len = paddle.to_tensor(text_len, dtype='int64') print(audio.shape) print(audio_len.shape) @@ -44,7 +38,7 @@ if __name__ == '__main__': print(text_len.shape) print("-----------------") - model = DeepSpeech2( + model = DeepSpeech2Model( feat_size=feat_dim, dict_size=10, num_conv_layers=2, @@ -56,7 +50,7 @@ if __name__ == '__main__': print('probs.shape', probs.shape) print("-----------------") - model2 = DeepSpeech2( + model2 = DeepSpeech2Model( feat_size=feat_dim, dict_size=10, num_conv_layers=2, @@ 
-68,7 +62,7 @@ if __name__ == '__main__': print('probs.shape', probs.shape) print("-----------------") - model3 = DeepSpeech2( + model3 = DeepSpeech2Model( feat_size=feat_dim, dict_size=10, num_conv_layers=2, @@ -80,7 +74,7 @@ if __name__ == '__main__': print('probs.shape', probs.shape) print("-----------------") - model4 = DeepSpeech2( + model4 = DeepSpeech2Model( feat_size=feat_dim, dict_size=10, num_conv_layers=2, @@ -92,7 +86,7 @@ if __name__ == '__main__': print('probs.shape', probs.shape) print("-----------------") - model5 = DeepSpeech2( + model5 = DeepSpeech2Model( feat_size=feat_dim, dict_size=10, num_conv_layers=2,
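
Usage note (illustrative sketch, not part of the patch): a minimal example of the torch-style paddle.Tensor shims that deepspeech/__init__.py registers above (size, view, eq, masked_fill, repeat). It assumes that importing the deepspeech package installs the shims and that the running Paddle release does not already provide these attributes (the hasattr guards skip registration when it does); exact behaviour may differ across Paddle versions.

import paddle

import deepspeech  # noqa: F401  - importing the package registers the shims above

x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

print(x.size())    # full shape, equivalent to paddle.shape(x)
print(x.size(0))   # a single dimension, torch-style
y = x.view(3, 2)   # reshape through the `view` shim
print(y.shape)

# `masked_fill` expects a bool mask with the same shape as the input
mask = x.eq(2.0)                        # elementwise compare via the `eq` shim
z = x.masked_fill(mask, float('-inf'))  # fill masked positions with -inf
print(z)

print(x.repeat(2, 1).shape)             # tile-based `repeat` shim -> [4, 3]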