diff --git a/.gitignore b/.gitignore
index 93b7544a4..db91b19e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,7 @@ tools/venv
 *.log
 *.pdmodel
 *.pdiparams*
+*.zip
+*.tar
+*.tar.gz
+.ipynb_checkpoints
diff --git a/.notebook/mask_and_masked_fill_test.ipynb b/.notebook/mask_and_masked_fill_test.ipynb
new file mode 100644
index 000000000..265ec536b
--- /dev/null
+++ b/.notebook/mask_and_masked_fill_test.ipynb
@@ -0,0 +1,449 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "primary-organic",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "id": "stopped-semester",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def mask_finished_scores(score: torch.Tensor,\n",
+    "                         flag: torch.Tensor) -> torch.Tensor:\n",
+    "    \"\"\"\n",
+    "    If a sequence is finished, we only allow one alive branch. This function\n",
+    "    aims to give one branch a zero score and the rest -inf score.\n",
+    "    Args:\n",
+    "        score (torch.Tensor): A real value array with shape\n",
+    "            (batch_size * beam_size, beam_size).\n",
+    "        flag (torch.Tensor): A bool array with shape\n",
+    "            (batch_size * beam_size, 1).\n",
+    "    Returns:\n",
+    "        torch.Tensor: (batch_size * beam_size, beam_size).\n",
+    "    \"\"\"\n",
+    "    beam_size = score.size(-1)\n",
+    "    zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n",
+    "    if beam_size > 1:\n",
+    "        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n",
+    "                               dim=1)\n",
+    "        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n",
+    "                             dim=1)\n",
+    "    else:\n",
+    "        unfinished = zero_mask\n",
+    "        finished = flag\n",
+    "    print(unfinished)\n",
+    "    print(finished)\n",
+    "    score.masked_fill_(unfinished, -float('inf'))\n",
+    "    score.masked_fill_(finished, 0)\n",
+    "    return score"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "id": "agreed-portuguese",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ True],\n",
+      "        [False]])\n",
+      "tensor([[-0.8841,  0.7381, -0.9986],\n",
+      "        [ 0.2675, -0.7971,  0.3798]])\n",
+      "tensor([[ True,  True],\n",
+      "        [False, False]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "score = torch.randn((2, 3))\n",
+    "flag = torch.ones((2, 1), dtype=torch.bool)\n",
+    "flag[1] = False\n",
+    "print(flag)\n",
+    "print(score)\n",
+    "print(flag.repeat([1, 2]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "id": "clean-aspect",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[False,  True,  True],\n",
+      "        [False, False, False]])\n",
+      "tensor([[ True, False, False],\n",
+      "        [False, False, False]])\n",
+      "tensor([[ 0.0000,    -inf,    -inf],\n",
+      "        [ 0.2675, -0.7971,  0.3798]])\n",
+      "tensor([[ 0.0000,    -inf,    -inf],\n",
+      "        [ 0.2675, -0.7971,  0.3798]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "r = mask_finished_scores(score, flag)\n",
+    "print(r)\n",
+    "print(score)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "id": "thrown-airline",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[True ],\n",
+      "        [False]])\n",
+      "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 2.05994511,  1.87704289,  0.01988174],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[True , True ],\n",
+      "        [False, False]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import paddle\n",
+    "\n",
+    "score = paddle.randn((2, 3))\n",
+    "flag = paddle.ones((2, 1), dtype='bool')\n",
+    "flag[1] = False\n",
+    "print(flag)\n",
+    "print(score)\n",
+    "print(flag.tile([1, 2]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "id": "internal-patent",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[False, True , True ],\n",
+      "        [False, False, False]])\n",
+      "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[True , False, False],\n",
+      "        [False, False, False]])\n",
+      "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 2.05994511,  1.87704289,  0.01988174],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 2.05994511,  1.87704289,  0.01988174],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 2.05994511, -inf.      , -inf.      ],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 2.05994511, -inf.      , -inf.      ],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 2.05994511, -inf.      , -inf.      ],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 0.        , -inf.      , -inf.      ],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
+      "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[ 0.        , -inf.      , -inf.      ],\n",
+      "        [-0.40165186,  0.77547729, -0.64469045]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "paddle.bool = 'bool'\n",
+    "\n",
+    "def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
+    "    print(xs)\n",
+    "    trues = paddle.ones_like(xs) * value\n",
+    "    assert xs.shape == mask.shape\n",
+    "    xs = paddle.where(mask, trues, xs)\n",
+    "    return xs\n",
+    "\n",
+    "def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
+    "    print('x', xs)\n",
+    "    trues = paddle.ones_like(xs) * value\n",
+    "    assert xs.shape == mask.shape\n",
+    "    ret = paddle.where(mask, trues, xs)\n",
+    "    print('2', xs)\n",
+    "    paddle.assign(ret, output=xs)\n",
+    "    print('3', xs)\n",
+    "\n",
+    "paddle.Tensor.masked_fill = masked_fill\n",
+    "paddle.Tensor.masked_fill_ = masked_fill_\n",
+    "\n",
+    "def mask_finished_scores_pd(score: paddle.Tensor,\n",
+    "                            flag: paddle.Tensor) -> paddle.Tensor:\n",
+    "    \"\"\"\n",
+    "    If a sequence is finished, we only allow one alive branch. This function\n",
+    "    aims to give one branch a zero score and the rest -inf score.\n",
+    "    Args:\n",
+    "        score (paddle.Tensor): A real value array with shape\n",
+    "            (batch_size * beam_size, beam_size).\n",
+    "        flag (paddle.Tensor): A bool array with shape\n",
+    "            (batch_size * beam_size, 1).\n",
+    "    Returns:\n",
+    "        paddle.Tensor: (batch_size * beam_size, beam_size).\n",
+    "    \"\"\"\n",
+    "    beam_size = score.shape[-1]\n",
+    "    zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n",
+    "    if beam_size > 1:\n",
+    "        unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n",
+    "                                   axis=1)\n",
+    "        finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n",
+    "                                 axis=1)\n",
+    "    else:\n",
+    "        unfinished = zero_mask\n",
+    "        finished = flag\n",
+    "    print(unfinished)\n",
+    "    print(finished)\n",
+    "    \n",
+    "    #score.masked_fill_(unfinished, -float('inf'))\n",
+    "    #score.masked_fill_(finished, 0)\n",
+    "# infs = paddle.ones_like(score) * -float('inf')\n",
+    "# score = paddle.where(unfinished, infs, score)\n",
+    "# score = paddle.where(finished, paddle.zeros_like(score), score)\n",
+    "\n",
+    "# score = score.masked_fill(unfinished, -float('inf'))\n",
+    "# score = score.masked_fill(finished, 0)\n",
+    "    score.masked_fill_(unfinished, -float('inf'))\n",
+    "    score.masked_fill_(finished, 0)\n",
+    "    return score\n",
+    "\n",
+    "r = mask_finished_scores_pd(score, flag)\n",
+    "print(r)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "id": "vocal-prime",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 57,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "score.value"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 71,
+   "id": "bacterial-adolescent",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Union, Any"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 72,
+   "id": "absent-fiber",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def repeat(xs : paddle.Tensor, *size: Any):\n",
+    "    print(size)\n",
+    "    return paddle.tile(xs, size)\n",
+    "paddle.Tensor.repeat = repeat"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 73,
+   "id": "material-harbor",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(1, 2)\n",
+      "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[True , True ],\n",
+      "        [False, False]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "flag = paddle.ones((2, 1), dtype='bool')\n",
+    "flag[1] = False\n",
+    "print(flag.repeat(1, 2))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 84,
+   "id": "acute-brighton",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [1]), 2)\n",
+      "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
+      "       [[True , True ],\n",
+      "        [False, False]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "flag = paddle.ones((2, 1), dtype='bool')\n",
+    "flag[1] = False\n",
+    "print(flag.repeat(paddle.to_tensor(1), 2))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 85,
+   "id": "european-rugby",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def size(xs, *args: int):\n",
+    "    nargs = len(args)\n",
+    "    s = paddle.shape(xs)\n",
+    "    assert(nargs <= 1)\n",
+    "    if nargs == 1:\n",
+    "        return s[args[0]]\n",
+    "    else:\n",
+    "        return s\n",
+    "paddle.Tensor.size = size"
+   ]
+  },
+  {
+   "cell_type": "code",
"execution_count": 86, + "id": "moral-special", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n", + " [2, 1])" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "flag.size()" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "ahead-coach", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", + " [1])" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "flag.size(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "incomplete-fitness", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", + " [2])" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "flag.size(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "upset-connectivity", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/deepspeech/modules/__init__.py b/deepspeech/modules/__init__.py index 185a92b8d..c8c35c8ba 100644 --- a/deepspeech/modules/__init__.py +++ b/deepspeech/modules/__init__.py @@ -11,3 +11,206 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typeing import Union +from typeing import Any + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +logger = logging.getLogger(__name__) + +# TODO(Hui Zhang): remove this hack +paddle.bool = 'bool' +paddle.float16 = 'float16' +paddle.float32 = 'float32' +paddle.float64 = 'float64' +paddle.int8 = 'int8' +paddle.int16 = 'int16' +paddle.int32 = 'int32' +paddle.int64 = 'int64' +paddle.uint8 = 'uint8' +paddle.complex64 = 'complex64' +paddle.complex128 = 'complex128' + +if not hasattr(paddle.Tensor, 'cat'): + logger.warn( + "override cat of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.cat = paddle.Tensor.concat + + +def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + nargs = len(args) + assert (nargs <= 1) + s = paddle.shape(xs) + if nargs == 1: + return s[args] + else: + return s + + +if not hasattr(paddle.Tensor, 'size'): + logger.warn( + "override size of paddle.Tensor if exists or register, remove this when fixed!" 
+    )
+    paddle.Tensor.size = size
+
+
+def masked_fill(xs: paddle.Tensor,
+                mask: paddle.Tensor,
+                value: Union[float, int]):
+    assert xs.shape == mask.shape
+    trues = paddle.ones_like(xs) * value
+    xs = paddle.where(mask, trues, xs)
+    return xs
+
+
+if not hasattr(paddle.Tensor, 'masked_fill'):
+    logger.warn(
+        "register user masked_fill to paddle.Tensor, remove this when fixed!")
+    paddle.Tensor.masked_fill = masked_fill
+
+
+def masked_fill_(xs: paddle.Tensor,
+                 mask: paddle.Tensor,
+                 value: Union[float, int]):
+    assert xs.shape == mask.shape
+    trues = paddle.ones_like(xs) * value
+    ret = paddle.where(mask, trues, xs)
+    paddle.assign(ret, output=xs)
+
+
+if not hasattr(paddle.Tensor, 'masked_fill_'):
+    logger.warn(
+        "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
+    paddle.Tensor.masked_fill_ = masked_fill_
+
+
+def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
+    return paddle.tile(xs, size)
+
+
+if not hasattr(paddle.Tensor, 'repeat'):
+    logger.warn(
+        "register user repeat to paddle.Tensor, remove this when fixed!")
+    paddle.Tensor.repeat = repeat
+
+# def softplus(x):
+#     """Softplus function."""
+#     if hasattr(paddle.nn.functional, 'softplus'):
+#         #return paddle.nn.functional.softplus(x.float()).type_as(x)
+#         return paddle.nn.functional.softplus(x)
+#     else:
+#         raise NotImplementedError
+
+# def gelu_accurate(x):
+#     """Gaussian Error Linear Units (GELU) activation."""
+#     # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
+#     if not hasattr(gelu_accurate, "_a"):
+#         gelu_accurate._a = math.sqrt(2 / math.pi)
+#     return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
+#                                       (x + 0.044715 * paddle.pow(x, 3))))
+
+# def gelu(x):
+#     """Gaussian Error Linear Units (GELU) activation."""
+#     if hasattr(nn.functional, 'gelu'):
+#         #return nn.functional.gelu(x.float()).type_as(x)
+#         return nn.functional.gelu(x)
+#     else:
+#         return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
+
+
+def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor:
+    """The gated linear unit (GLU) activation."""
+    a, b = x.split(2, axis=dim)
+    act_b = F.sigmoid(b)
+    return a * act_b
+
+
+if not hasattr(paddle.nn.functional, 'glu'):
+    logger.warn(
+        "register user glu to paddle.nn.functional, remove this when fixed!")
+    setattr(paddle.nn.functional, 'glu', glu)
+
+
+# TODO(Hui Zhang): remove this activation
+class GLU(nn.Layer):
+    """Gated Linear Units (GLU) Layer"""
+
+    def __init__(self, dim: int=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, xs):
+        return glu(xs, dim=self.dim)
+
+
+if not hasattr(paddle.nn, 'GLU'):
+    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
+    setattr(paddle.nn, 'GLU', GLU)
+
+
+# TODO(Hui Zhang): remove this Layer
+class ConstantPad2d(nn.Layer):
+    """Pads the input tensor boundaries with a constant value.
+    For N-dimensional padding, use paddle.nn.functional.pad().
+    """
+
+    def __init__(self, padding: Union[tuple, list, int], value: float):
+        """
+        Args:
+            padding (Union[tuple, list, int]): the size of the padding.
+                If an int, uses the same padding in all boundaries.
+                If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
+            value (float): pad value
+        """
+        super().__init__()
+        self.padding = padding if isinstance(padding,
+                                             (tuple, list)) else [padding] * 4
+        self.value = value
+
+    def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
+        return nn.functional.pad(
+            xs,
+            self.padding,
+            mode='constant',
+            value=self.value,
+            data_format='NCHW')
+
+
+if not hasattr(paddle.nn, 'ConstantPad2d'):
+    logger.warn(
+        "register user ConstantPad2d to paddle.nn, remove this when fixed!")
+    setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
+
+
+# hack loss
+def ctc_loss(logits,
+             labels,
+             input_lengths,
+             label_lengths,
+             blank=0,
+             reduction='mean',
+             norm_by_times=True):
+    #logger.info("my ctc loss with norm by times")
+    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
+    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
+                                           input_lengths, label_lengths)
+
+    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
+    logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ")
+    assert reduction in ['mean', 'sum', 'none']
+    if reduction == 'mean':
+        loss_out = paddle.mean(loss_out / label_lengths)
+    elif reduction == 'sum':
+        loss_out = paddle.sum(loss_out)
+    logger.info(f"ctc loss: {loss_out}")
+    return loss_out
+
+
+logger.warn(
+    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
+)
+F.ctc_loss = ctc_loss
diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py
index ecaca5bca..827791f36 100644
--- a/deepspeech/modules/activation.py
+++ b/deepspeech/modules/activation.py
@@ -25,9 +25,7 @@ from paddle.nn import initializer as I
 
 logger = logging.getLogger(__name__)
 
-__all__ = [
-    "brelu", "glu", "GLU", "LinearGLUBlock", "ConstantPad2d", "ConvGLUBlock"
-]
+__all__ = ["brelu", "LinearGLUBlock", "ConstantPad2d", "ConvGLUBlock"]
 
 
 def brelu(x, t_min=0.0, t_max=24.0, name=None):
@@ -37,61 +35,6 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
     return x.maximum(t_min).minimum(t_max)
 
 
-# def softplus(x):
-#     """Softplus function."""
-#     if hasattr(paddle.nn.functional, 'softplus'):
-#         #return paddle.nn.functional.softplus(x.float()).type_as(x)
-#         return paddle.nn.functional.softplus(x)
-#     else:
-#         raise NotImplementedError
-
-# def gelu_accurate(x):
-#     """Gaussian Error Linear Units (GELU) activation."""
-#     # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
-#     if not hasattr(gelu_accurate, "_a"):
-#         gelu_accurate._a = math.sqrt(2 / math.pi)
-#     return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
-#                                       (x + 0.044715 * paddle.pow(x, 3))))
-
-# def gelu(x):
-#     """Gaussian Error Linear Units (GELU) activation."""
-#     if hasattr(nn.functional, 'gelu'):
-#         #return nn.functional.gelu(x.float()).type_as(x)
-#         return nn.functional.gelu(x)
-#     else:
-#         return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
-
-
-# TODO(Hui Zhang): remove this activation
-def glu(x, dim=-1):
-    """The gated linear unit (GLU) activation."""
-    if hasattr(nn.functional, 'glu'):
-        return nn.functional.glu(x)
-    else:
-        a, b = x.split(2, axis=dim)
-        act_b = F.sigmoid(b)
-        return a * act_b
-
-
-# TODO(Hui Zhang): remove this activation
-if not hasattr(nn.functional, 'glu'):
-    logger.warn(
-        "register user glu to paddle.nn.functional, remove this when fixed!")
-    setattr(nn.functional, 'glu', glu)
-
-
-# TODO(Hui Zhang): remove this activation
-class GLU(nn.Layer):
-    """Gated Linear Units (GLU) Layer"""
-
-    def __init__(self, dim: int=-1):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, xs):
-        return glu(xs, dim=self.dim)
-
-
 class LinearGLUBlock(nn.Layer):
     """A linear Gated Linear Units (GLU) block."""
 
diff --git a/deepspeech/modules/conformer_convolution.py b/deepspeech/modules/conformer_convolution.py
index 4c3eb9f4f..c3e3e052e 100644
--- a/deepspeech/modules/conformer_convolution.py
+++ b/deepspeech/modules/conformer_convolution.py
@@ -22,10 +22,6 @@ from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 
-# init F.glu func
-# TODO(Hui Zhang): remove this line
-import deepspeech.modules.activation
-
 logger = logging.getLogger(__name__)
 
 __all__ = ['ConvolutionModule']
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index bf06b6da1..9e1d34a89 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -24,34 +24,6 @@ logger = logging.getLogger(__name__)
 
 __all__ = ['CTCLoss', "LabelSmoothingLoss"]
 
-# TODO(Hui Zhang): remove this hack, when `norm_by_times=True` is added
-def ctc_loss(logits,
-             labels,
-             input_lengths,
-             label_lengths,
-             blank=0,
-             reduction='mean',
-             norm_by_times=True):
-    #logger.info("my ctc loss with norm by times")
-    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
-
-    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
-    logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ")
-    assert reduction in ['mean', 'sum', 'none']
-    if reduction == 'mean':
-        loss_out = paddle.mean(loss_out / label_lengths)
-    elif reduction == 'sum':
-        loss_out = paddle.sum(loss_out)
-    logger.info(f"ctc loss: {loss_out}")
-    return loss_out
-
-
-# TODO(Hui Zhang): remove this hack
-F.ctc_loss = ctc_loss
-
-
 class CTCLoss(nn.Layer):
     def __init__(self, blank=0, reduction='sum'):
         super().__init__()
@@ -149,12 +121,14 @@ class LabelSmoothingLoss(nn.Layer):
 
         ignore = target == self.padding_idx  # (B,)
         ignore = ignore.cast(target.dtype)
-        target = target * (1 - ignore)  # avoid -1 index
+        #target = target * (1 - ignore)  # avoid -1 index
+        target = target.masked_fill(ignore, 0)  # avoid -1 index
         true_dist += F.one_hot(target, self.size) * self.confidence
 
         kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
         total = len(target) - int(ignore.sum())
         denom = total if self.normalize_length else B
-        numer = (kl * (1 - ignore)).sum()
+        #numer = (kl * (1 - ignore)).sum()
+        numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
         return numer / denom
 
 
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index 0f136403f..7d30f060d 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -25,6 +25,21 @@ __all__ = ['sequence_mask']
 
 
 def sequence_mask(x_len, max_len=None, dtype='float32'):
+    """Generate a [B, Tmax] mask tensor from sequence lengths.
+
+    Args:
+        x_len (paddle.Tensor): xs length, [B]
+        max_len (int, optional): max sequence length. Defaults to None.
+        dtype (str, optional): mask data type. Defaults to 'float32'.
+
+    Returns:
+        paddle.Tensor: [B, Tmax]
+
+    Examples:
+        >>> sequence_mask([2, 4])
+        [[1., 1., 0., 0.],
+        [1., 1., 1., 1.]]
+    """
     max_len = max_len or x_len.max()
     x_len = paddle.unsqueeze(x_len, -1)
     row_vector = paddle.arange(max_len)
@@ -33,3 +48,230 @@
     mask = row_vector > x_len  # a bug: it goes wrong when broadcasting
     mask = paddle.cast(mask, dtype)
     return mask
+
+
+def subsequent_mask(
+        size: int, ) -> paddle.Tensor:
+    """Create mask for subsequent steps (size, size).
+    This mask is used only in decoder which works in an auto-regressive mode.
+    This means the current step could only do attention with its left steps.
+    In encoder, full attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+    When streaming is needed, chunk-based attention is used in encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+    Args:
+        size (int): size of mask
+    Returns:
+        paddle.Tensor: mask
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    ret = paddle.ones([size, size], dtype=paddle.bool)
+    return paddle.tril(ret)
+
+
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int=-1, ) -> paddle.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for streaming encoder
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+    Returns:
+        paddle.Tensor: mask
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = paddle.zeros([size, size], dtype=paddle.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
+def add_optional_chunk_mask(xs: paddle.Tensor,
+                            masks: paddle.Tensor,
+                            use_dynamic_chunk: bool,
+                            use_dynamic_left_chunk: bool,
+                            decoding_chunk_size: int,
+                            static_chunk_size: int,
+                            num_decoding_left_chunks: int):
+    """ Apply optional mask for encoder.
+    Args:
+        xs (paddle.Tensor): padded input, (B, L, D), L for max length
+        masks (paddle.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        static_chunk_size (int): chunk size for static chunk training/decoding
+            if it's greater than 0, if use_dynamic_chunk is true,
+            this parameter will be ignored
+        num_decoding_left_chunks (int): number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+            >=0: use num_decoding_left_chunks
+            <0: use all left chunks
+    Returns:
+        paddle.Tensor: chunk mask of the input xs.
+    """
+    # Whether to use chunk mask or not
+    if use_dynamic_chunk:
+        max_len = xs.shape[1]
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            # chunk size is either [1, 25] or full context(max_len).
+            # Since we use 4 times subsampling and allow up to 1s(100 frames)
+            # delay, the maximum frame is 100 / 4 = 25.
+            chunk_size = int(paddle.randint(1, max_len, (1, )))
+            num_left_chunks = -1
+            if chunk_size > max_len // 2:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = int(
+                        paddle.randint(0, max_left_chunks, (1, )))
+        chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
+                                            num_left_chunks)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
+                                            num_left_chunks)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    return chunk_masks
+
+
+def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of padded part.
+    See description of make_non_pad_mask.
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+    Returns:
+        paddle.Tensor: Mask tensor containing indices of padded part.
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = int(lengths.shape[0])
+    max_len = int(lengths.max())
+    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of non-padded part.
+    The sequences in a batch may have different lengths. To enable
+    batch computing, padding is needed to make all sequences the same
+    size. To avoid the padding part passing values to context-dependent
+    blocks such as attention or convolution, this padding part is
+    masked.
+    This pad_mask is used in both encoder and decoder.
+    1 for non-padded part and 0 for padded part.
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+    Returns:
+        paddle.Tensor: mask tensor containing indices of non-padded part.
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1, 1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+    """
+    return ~make_pad_mask(lengths)
+
+
+def mask_finished_scores(score: paddle.Tensor,
+                         flag: paddle.Tensor) -> paddle.Tensor:
+    """
+    If a sequence is finished, we only allow one alive branch. This function
+    aims to give one branch a zero score and the rest -inf score.
+    Args:
+        score (paddle.Tensor): A real value array with shape
+            (batch_size * beam_size, beam_size).
+        flag (paddle.Tensor): A bool array with shape
+            (batch_size * beam_size, 1).
+    Returns:
+        paddle.Tensor: (batch_size * beam_size, beam_size).
+    Examples:
+        flag: tensor([[ True],
+                      [False]])
+        score: tensor([[-0.3666, -0.6664,  0.6019],
+                       [-1.1490, -0.2948,  0.7460]])
+        unfinished: tensor([[False,  True,  True],
+                            [False, False, False]])
+        finished: tensor([[ True, False, False],
+                          [False, False, False]])
+        return: tensor([[ 0.0000,    -inf,    -inf],
+                        [-1.1490, -0.2948,  0.7460]])
+    """
+    beam_size = score.shape[-1]
+    zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)
+    if beam_size > 1:
+        unfinished = paddle.concat(
+            (zero_mask, flag.tile([1, beam_size - 1])), axis=1)
+        finished = paddle.concat(
+            (flag, zero_mask.tile([1, beam_size - 1])), axis=1)
+    else:
+        unfinished = zero_mask
+        finished = flag
+
+    # infs = paddle.ones_like(score) * -float('inf')
+    # score = paddle.where(unfinished, infs, score)
+    # score = paddle.where(finished, paddle.zeros_like(score), score)
+    score.masked_fill_(unfinished, -float('inf'))
+    score.masked_fill_(finished, 0)
+    return score
+
+
+def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor,
+                        eos: int) -> paddle.Tensor:
+    """
+    If a sequence is finished, all of its branches should be <eos>.
+    Args:
+        pred (paddle.Tensor): An int array with shape
+            (batch_size * beam_size, beam_size).
+        flag (paddle.Tensor): A bool array with shape
+            (batch_size * beam_size, 1).
+    Returns:
+        paddle.Tensor: (batch_size * beam_size, beam_size).
+    """
+    beam_size = pred.shape[-1]
+    finished = flag.tile([1, beam_size])
+    return pred.masked_fill_(finished, eos)
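
A minimal standalone sketch of the masked_fill shims this patch registers (an illustration, not part of the patch itself; it re-creates the same paddle.where-based logic inline so it runs on any paddle >= 2.0 install without this repo, and it mirrors what the notebook above verifies):

import paddle

def masked_fill(xs, mask, value):
    # Out-of-place: take `value` where mask is True, keep xs elsewhere.
    assert xs.shape == mask.shape
    trues = paddle.ones_like(xs) * value
    return paddle.where(mask, trues, xs)

def masked_fill_(xs, mask, value):
    # In-place flavor: paddle.where allocates a new tensor, so the result
    # is copied back into xs's storage via paddle.assign.
    ret = masked_fill(xs, mask, value)
    paddle.assign(ret, output=xs)
    return xs

score = paddle.to_tensor([[0.5, -0.2, 0.1],
                          [0.3, 0.7, -0.4]])
# Row 0 is a finished hypothesis: keep one alive branch, kill the rest,
# exactly as mask_finished_scores does with its unfinished/finished masks.
unfinished = paddle.to_tensor([[False, True, True],
                               [False, False, False]])
finished = paddle.to_tensor([[True, False, False],
                             [False, False, False]])
masked_fill_(score, unfinished, -float('inf'))
masked_fill_(score, finished, 0)
print(score)
# Expected: [[ 0. , -inf, -inf],
#            [ 0.3,  0.7, -0.4]]

The second print('2', xs) / print('3', xs) pair in the notebook's internal-patent cell demonstrates why the paddle.assign step is needed: paddle.where alone leaves the original tensor untouched, so without the assign the "in-place" variant would silently be a no-op.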