diff --git a/.gitignore b/.gitignore index e4134a082..cd2360e15 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,7 @@ tools/sox-14.4.2 tools/soxbindings tools/montreal-forced-aligner/ tools/Montreal-Forced-Aligner/ +tools/sctk +tools/sctk-20159b5/ *output/ diff --git a/.notebook/Linear_test.ipynb b/.notebook/Linear_test.ipynb deleted file mode 100644 index 5c7370cf3..000000000 --- a/.notebook/Linear_test.ipynb +++ /dev/null @@ -1,605 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "academic-surname", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from paddle import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "fundamental-treasure", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "L = nn.Linear(256, 2048)\n", - "L2 = nn.Linear(2048, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "consolidated-elephant", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "moderate-noise", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n", - "Tensor(shape=[2, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[-1.54171216, -2.61531472, -1.79881978, ..., -0.31395876, 0.56513089, -0.44516513],\n", - " [-0.79492962, 1.91157901, 0.66567147, ..., 0.54825783, -1.01471853, -0.84924090],\n", - " [-1.22556651, -0.36225814, 0.65063190, ..., 0.65726501, 0.05563191, 0.09009409],\n", - " ...,\n", - " [ 0.38615900, -0.77905393, 0.99732304, ..., -1.38463700, -3.32365036, -1.31089687],\n", - " [ 0.05579993, 0.06885809, -1.66662002, ..., -0.23346378, -3.29372883, 1.30561364],\n", - " [ 1.90676069, 1.95093191, -0.28849599, ..., -0.06860496, 0.95347673, 1.00475824]],\n", - "\n", - " [[-0.91453546, 0.55298805, -1.06146812, ..., -0.86378336, 1.00454640, 1.26062179],\n", - " [ 0.10223761, 0.81301165, 2.36865163, ..., 0.16821407, 0.29240361, 1.05408621],\n", - " [-1.33196676, 1.94433689, 0.01934209, ..., 0.48036841, 0.51585966, 1.22893548],\n", - " ...,\n", - " [-0.19558455, -0.47075930, 0.90796155, ..., -1.28598249, -0.24321797, 0.17734711],\n", - " [ 0.89819717, -1.39516675, 0.17138045, ..., 2.39761519, 1.76364994, -0.52177650],\n", - " [ 0.94122332, -0.18581429, 1.36099780, ..., 0.67647684, -0.04699665, 1.51205540]]])\n", - "tensor([[[-1.5417, -2.6153, -1.7988, ..., -0.3140, 0.5651, -0.4452],\n", - " [-0.7949, 1.9116, 0.6657, ..., 0.5483, -1.0147, -0.8492],\n", - " [-1.2256, -0.3623, 0.6506, ..., 0.6573, 0.0556, 0.0901],\n", - " ...,\n", - " [ 0.3862, -0.7791, 0.9973, ..., -1.3846, -3.3237, -1.3109],\n", - " [ 0.0558, 0.0689, -1.6666, ..., -0.2335, -3.2937, 1.3056],\n", - " [ 1.9068, 1.9509, -0.2885, ..., -0.0686, 0.9535, 1.0048]],\n", - "\n", - " [[-0.9145, 0.5530, -1.0615, ..., -0.8638, 1.0045, 1.2606],\n", - " [ 0.1022, 0.8130, 2.3687, ..., 0.1682, 0.2924, 1.0541],\n", - " [-1.3320, 1.9443, 0.0193, ..., 0.4804, 0.5159, 1.2289],\n", - " ...,\n", - " [-0.1956, -0.4708, 0.9080, ..., -1.2860, -0.2432, 0.1773],\n", - " [ 0.8982, -1.3952, 0.1714, ..., 2.3976, 1.7636, -0.5218],\n", - " [ 0.9412, -0.1858, 1.3610, ..., 0.6765, -0.0470, 1.5121]]])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 51, 256)\n", - "print(x.dtype)\n", - "px = paddle.to_tensor(x, dtype='float32')\n", - "tx = torch.tensor(x, dtype=torch.float32)\n", - "print(px)\n", - "print(tx)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cooked-progressive", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "mechanical-prisoner", - "metadata": {}, - "outputs": [], - "source": [ - "data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n", - "t_norm_ff = data['norm_ff']\n", - "t_ff_out = data['ff_out']\n", - "t_ff_l_x = data['ff_l_x']\n", - "t_ff_l_a_x = data['ff_l_a_x']\n", - "t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - "t_ps = data['ps']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "indie-marriage", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "assured-zambia", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n", - "L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n", - "\n", - "ps = []\n", - "for n, p in L.named_parameters():\n", - " ps.append(p)\n", - "\n", - "for n, p in L2.state_dict().items():\n", - " ps.append(p)\n", - " \n", - "for p, tp in zip(ps, t_ps):\n", - " print(np.allclose(p.numpy(), tp.T))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "committed-jacob", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "extreme-traffic", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "optimum-milwaukee", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "viral-indian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "# data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n", - "# t_norm_ff = data['norm_ff']\n", - "# t_ff_out = data['ff_out']\n", - "# t_ff_l_x = data['ff_l_x']\n", - "# t_ff_l_a_x = data['ff_l_a_x']\n", - "# t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - "# t_ps = data['ps']\n", - "TL = torch.nn.Linear(256, 2048)\n", - "TL2 = torch.nn.Linear(2048, 256)\n", - "TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n", - "TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n", - "\n", - "# for n, p in TL.named_parameters():\n", - "# print(n, p)\n", - "# for n, p in TL2.named_parameters():\n", - "# print(n, p)\n", - "\n", - "ps = []\n", - "for n, p in TL.state_dict().items():\n", - " ps.append(p.data.numpy())\n", - " \n", - "for n, p in TL2.state_dict().items():\n", - " ps.append(p.data.numpy())\n", - " \n", - "for p, tp in zip(ps, t_ps):\n", - " print(np.allclose(p, tp))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "skilled-vietnamese", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.67277956 0.08313607 -0.62761104 ... -0.17480263 0.42718208\n", - " -0.5787626 ]\n", - " [ 0.91516656 0.5393416 1.7159258 ... 0.06144593 0.06486575\n", - " -0.03350811]\n", - " [ 0.438351 0.6227843 0.24096036 ... 1.0912522 -0.90929437\n", - " -1.012989 ]\n", - " ...\n", - " [ 0.68631977 0.14240924 0.10763275 ... -0.11513516 0.48065388\n", - " 0.04070369]\n", - " [-0.9525228 0.23197874 0.31264272 ... 0.5312439 0.18773697\n", - " -0.8450228 ]\n", - " [ 0.42024016 -0.04561988 0.54541194 ... -0.41933843 -0.00436018\n", - " -0.06663495]]\n", - "\n", - " [[-0.11638781 -0.33566502 -0.20887226 ... 0.17423287 -0.9195841\n", - " -0.8161046 ]\n", - " [-0.3469874 0.88269687 -0.11887559 ... -0.15566081 0.16357468\n", - " -0.20766167]\n", - " [-0.3847657 0.3984318 -0.06963477 ... -0.00360622 1.2360432\n", - " -0.26811332]\n", - " ...\n", - " [ 0.08230796 -0.46158582 0.54582864 ... 0.15747628 -0.44790155\n", - " 0.06020184]\n", - " [-0.8095085 0.43163058 -0.42837143 ... 0.8627463 0.90656304\n", - " 0.15847842]\n", - " [-1.485811 -0.18216592 -0.8882585 ... 0.32596245 0.7822631\n", - " -0.6460344 ]]]\n", - "[[[ 0.67278004 0.08313602 -0.6276114 ... -0.17480245 0.42718196\n", - " -0.5787625 ]\n", - " [ 0.91516703 0.5393413 1.7159253 ... 0.06144581 0.06486579\n", - " -0.03350812]\n", - " [ 0.43835106 0.62278455 0.24096027 ... 1.0912521 -0.9092943\n", - " -1.0129892 ]\n", - " ...\n", - " [ 0.6863195 0.14240888 0.10763284 ... -0.11513527 0.48065376\n", - " 0.04070365]\n", - " [-0.9525231 0.23197863 0.31264275 ... 0.53124386 0.18773702\n", - " -0.84502304]\n", - " [ 0.42024007 -0.04561983 0.545412 ... -0.41933888 -0.00436005\n", - " -0.066635 ]]\n", - "\n", - " [[-0.11638767 -0.33566508 -0.20887226 ... 0.17423296 -0.9195838\n", - " -0.8161046 ]\n", - " [-0.34698725 0.88269705 -0.11887549 ... -0.15566081 0.16357464\n", - " -0.20766166]\n", - " [-0.3847657 0.3984319 -0.06963488 ... -0.00360619 1.2360426\n", - " -0.26811326]\n", - " ...\n", - " [ 0.08230786 -0.4615857 0.5458287 ... 0.15747619 -0.44790167\n", - " 0.06020182]\n", - " [-0.8095083 0.4316307 -0.42837155 ... 0.862746 0.9065631\n", - " 0.15847899]\n", - " [-1.485811 -0.18216613 -0.8882584 ... 0.32596254 0.7822631\n", - " -0.6460344 ]]]\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "y = L(px)\n", - "print(y.numpy())\n", - "\n", - "ty = TL(tx)\n", - "print(ty.data.numpy())\n", - "print(np.allclose(px.numpy(), tx.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "incorrect-allah", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "prostate-cameroon", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "governmental-surge", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.04476918 0.554463 -0.3027508 ... -0.49600336 0.3751858\n", - " 0.8254095 ]\n", - " [ 0.95594174 -0.29528382 -1.2899452 ... 0.43718258 0.05584608\n", - " -0.06974669]]\n", - "[[ 0.04476918 0.5544631 -0.3027507 ... -0.49600336 0.37518573\n", - " 0.8254096 ]\n", - " [ 0.95594174 -0.29528376 -1.2899454 ... 0.4371827 0.05584623\n", - " -0.0697467 ]]\n", - "True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 256)\n", - "px = paddle.to_tensor(x, dtype='float32')\n", - "tx = torch.tensor(x, dtype=torch.float32)\n", - "y = L(px)\n", - "print(y.numpy())\n", - "ty = TL(tx)\n", - "print(ty.data.numpy())\n", - "print(np.allclose(px.numpy(), tx.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy(), atol=1e-5))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "confidential-jacket", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "improved-civilization", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "5e7e7c9fde8350084abf1898cf52651cfc84b17a\n" - ] - } - ], - "source": [ - "print(paddle.version.commit)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "d1e2d3b4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " 'commit',\n", - " 'full_version',\n", - " 'istaged',\n", - " 'major',\n", - " 'minor',\n", - " 'mkl',\n", - " 'patch',\n", - " 'rc',\n", - " 'show',\n", - " 'with_mkl']" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(paddle.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "c880c719", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.1.0\n" - ] - } - ], - "source": [ - "print(paddle.version.full_version)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "f26977bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "commit: 5e7e7c9fde8350084abf1898cf52651cfc84b17a\n", - "None\n" - ] - } - ], - "source": [ - "print(paddle.version.show())" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "04ad47f6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.6.0\n" - ] - } - ], - "source": [ - "print(torch.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e1e03830", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " '__version__',\n", - " 'cuda',\n", - " 'debug',\n", - " 'git_version',\n", - " 'hip']" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(torch.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "4ad0389b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'b31f58de6fa8bbda5353b3c77d9be4914399724d'" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.version.git_version" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "7870ea10", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'10.2'" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.version.cuda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "db8ee5a7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6321ec2a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/WarmupLR.ipynb b/.notebook/WarmupLR.ipynb deleted file mode 100644 index 21abf9cbe..000000000 --- a/.notebook/WarmupLR.ipynb +++ /dev/null @@ -1,339 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "d6a0e098", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Union\n", - "\n", - "import torch\n", - "from torch.optim.lr_scheduler import _LRScheduler\n", - "\n", - "from typeguard import check_argument_types\n", - "\n", - "\n", - "class WarmupLR(_LRScheduler):\n", - " \"\"\"The WarmupLR scheduler\n", - " This scheduler is almost same as NoamLR Scheduler except for following\n", - " difference:\n", - " NoamLR:\n", - " lr = optimizer.lr * model_size ** -0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " WarmupLR:\n", - " lr = optimizer.lr * warmup_step ** 0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " optimizer: torch.optim.Optimizer,\n", - " warmup_steps: Union[int, float] = 25000,\n", - " last_epoch: int = -1,\n", - " ):\n", - " assert check_argument_types()\n", - " self.warmup_steps = warmup_steps\n", - "\n", - " # __init__() must be invoked before setting field\n", - " # because step() is also invoked in __init__()\n", - " super().__init__(optimizer, last_epoch)\n", - "\n", - " def __repr__(self):\n", - " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", - "\n", - " def get_lr(self):\n", - " step_num = self.last_epoch + 1\n", - " return [\n", - " lr\n", - " * self.warmup_steps ** 0.5\n", - " * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n", - " for lr in self.base_lrs\n", - " ]\n", - "\n", - " def set_step(self, step: int):\n", - " self.last_epoch = step" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0d496677", - "metadata": {}, - "outputs": [], - "source": [ - "import torch.optim as optim\n", - "model = torch.nn.Linear(10, 200)\n", - "optimizer = optim.Adam(model.parameters())\n", - "scheduler = WarmupLR(optimizer, warmup_steps=25000)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e3e3f3dc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0.0 -1\n" - ] - } - ], - "source": [ - "infos = {}\n", - "start_epoch = infos.get('epoch', -1) + 1\n", - "cv_loss = infos.get('cv_loss', 0.0)\n", - "step = infos.get('step', -1)\n", - "print(start_epoch, cv_loss, step)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "dc3d550c", - "metadata": {}, - "outputs": [], - "source": [ - "scheduler.set_step(step)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "e527634e", - "metadata": {}, - "outputs": [], - "source": [ - "lrs=[]\n", - "for i in range(100000):\n", - " scheduler.step()\n", - " lrs.append(scheduler.get_lr())" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f1452db9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting matplotlib\n", - " Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n", - "\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n", - "\u001b[?25hCollecting kiwisolver>=1.0.1\n", - " Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n", - "\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n", - "Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n", - "Collecting cycler>=0.10\n", - " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", - "Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n", - "Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", - "Installing collected packages: kiwisolver, cycler, matplotlib\n", - "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" - ] - } - ], - "source": [ - "!pip install matplotlib\n", - "import matplotlib.pyplot as plt\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "0f36d04f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wzuUATwG/bquTc+4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3N+3l3Y2VvLiiAoDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHShYwUAr1IBQBGJIgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz7cQmDt4lj7zd7VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "xs = list(range(100000))\n", - "plt.plot(xs, lrs)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "4f4e282c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "from typing import Union\n", - "\n", - "from paddle.optimizer.lr import LRScheduler\n", - "from typeguard import check_argument_types\n", - "\n", - "class WarmupLR(LRScheduler):\n", - " \"\"\"The WarmupLR scheduler\n", - " This scheduler is almost same as NoamLR Scheduler except for following\n", - " difference:\n", - " NoamLR:\n", - " lr = optimizer.lr * model_size ** -0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " WarmupLR:\n", - " lr = optimizer.lr * warmup_step ** 0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " warmup_steps: Union[int, float]=25000,\n", - " learning_rate=1.0,\n", - " last_epoch=-1,\n", - " verbose=False):\n", - " assert check_argument_types()\n", - " self.warmup_steps = warmup_steps\n", - " super().__init__(learning_rate, last_epoch, verbose)\n", - "\n", - " def __repr__(self):\n", - " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", - "\n", - " def get_lr(self):\n", - " step_num = self.last_epoch + 1\n", - " return self.base_lr * self.warmup_steps**0.5 * min(\n", - " step_num**-0.5, step_num * self.warmup_steps**-1.5)\n", - "\n", - " def set_step(self, step: int):\n", - " self.step(step)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "8c40b202", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-1\n" - ] - } - ], - "source": [ - "sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n", - "print(step)\n", - "#sc.set_step(step)\n", - "sc.set_step(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "ecbc7e37", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA/gj8JumOjnnnnLO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0gQVGKf87ghAyvrak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBRt6GQBwLEqACgi4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfpN39cQkoMo7/FwIXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "lrs=[]\n", - "for i in range(100000):\n", - " sc.step()\n", - " lrs.append(sc.get_lr())\n", - "xs = list(range(100000))\n", - "plt.plot(xs, lrs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e613fe16", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f0fd9f40", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb deleted file mode 100644 index 04b4a3924..000000000 --- a/.notebook/audio_feature.ipynb +++ /dev/null @@ -1,1207 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 94, - "id": "matched-camera", - "metadata": {}, - "outputs": [], - "source": [ - "from nnAudio import Spectrogram\n", - "from scipy.io import wavfile\n", - "import torch\n", - "import soundfile as sf\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "quarterly-solution", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[43 75 69 ... 7 6 3]\n", - "[43 75 69 ... 7 6 3]\n", - "[43 75 69 ... 7 6 3]\n" - ] - } - ], - "source": [ - "import scipy.io.wavfile as wav\n", - "\n", - "rate,sig = wav.read('./BAC009S0764W0124.wav')\n", - "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", - "sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sig)\n", - "print(song)\n", - "print(sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "id": "middle-salem", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "[43 75 69 ... 7 6 3]\n", - "(83792,)\n", - "int16\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2733 seconds\n", - "tensor([[[[-4.0940e+03, 1.2600e+04],\n", - " [ 8.5108e+03, -5.4930e+03],\n", - " [-3.3631e+03, -1.7904e+03],\n", - " ...,\n", - " [ 8.2279e+03, -9.3340e+03],\n", - " [-3.1990e+03, 2.0969e+03],\n", - " [-1.2669e+03, 4.4488e+03]],\n", - "\n", - " [[ 3.4886e+03, -9.9620e+03],\n", - " [-4.5364e+03, 4.1907e+02],\n", - " [ 2.5074e+03, 7.1339e+03],\n", - " ...,\n", - " [-5.4819e+03, 3.9258e+01],\n", - " [ 4.7221e+03, 6.5887e+01],\n", - " [ 9.6492e+02, -3.4386e+03]],\n", - "\n", - " [[-3.4947e+03, 9.2981e+03],\n", - " [-7.5164e+03, 8.1856e+02],\n", - " [-5.3766e+03, -9.0889e+03],\n", - " ...,\n", - " [ 1.4317e+03, 5.7447e+03],\n", - " [-3.1178e+03, 3.0740e+03],\n", - " [-3.4351e+03, 5.6900e+02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.7112e+01, -4.5737e+00],\n", - " [-9.6295e+00, 3.5554e+01],\n", - " [ 1.8527e+00, -1.0491e+01],\n", - " ...,\n", - " [-1.1157e+01, 3.4423e+00],\n", - " [ 3.1193e+00, -4.4388e+00],\n", - " [-8.8242e+00, 8.0324e+00]],\n", - "\n", - " [[-6.5080e+01, 2.9543e+00],\n", - " [ 3.9992e+01, -1.3836e+01],\n", - " [-9.2803e+00, 1.0318e+01],\n", - " ...,\n", - " [ 4.2928e+00, 9.2397e+00],\n", - " [ 3.6642e+00, 9.4680e+00],\n", - " [ 4.8932e+00, -2.5199e+01]],\n", - "\n", - " [[ 4.7264e+01, -1.0721e+00],\n", - " [-6.0516e+00, -1.4589e+01],\n", - " [ 1.3127e+01, 1.4995e+00],\n", - " ...,\n", - " [ 1.7333e+01, -1.4380e+01],\n", - " [-3.6046e+00, -6.1019e+00],\n", - " [ 1.3321e+01, 2.3184e+01]]]])\n" - ] - } - ], - "source": [ - "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", - "print(sr)\n", - "print(song)\n", - "print(song.shape)\n", - "print(song.dtype)\n", - "x = song\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", - " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", - " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", - "\n", - "spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "print(spec)" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "finished-sterling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "[43 75 69 ... 7 6 3]\n", - "(83792,)\n", - "int16\n", - "True\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2001 seconds\n", - "torch.Size([1, 1025, 164, 2])\n", - "tensor([[[[-4.0940e+03, 1.2600e+04],\n", - " [ 8.5108e+03, -5.4930e+03],\n", - " [-3.3631e+03, -1.7904e+03],\n", - " ...,\n", - " [ 8.2279e+03, -9.3340e+03],\n", - " [-3.1990e+03, 2.0969e+03],\n", - " [-1.2669e+03, 4.4488e+03]],\n", - "\n", - " [[ 3.4886e+03, -9.9620e+03],\n", - " [-4.5364e+03, 4.1907e+02],\n", - " [ 2.5074e+03, 7.1339e+03],\n", - " ...,\n", - " [-5.4819e+03, 3.9258e+01],\n", - " [ 4.7221e+03, 6.5887e+01],\n", - " [ 9.6492e+02, -3.4386e+03]],\n", - "\n", - " [[-3.4947e+03, 9.2981e+03],\n", - " [-7.5164e+03, 8.1856e+02],\n", - " [-5.3766e+03, -9.0889e+03],\n", - " ...,\n", - " [ 1.4317e+03, 5.7447e+03],\n", - " [-3.1178e+03, 3.0740e+03],\n", - " [-3.4351e+03, 5.6900e+02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.7112e+01, -4.5737e+00],\n", - " [-9.6295e+00, 3.5554e+01],\n", - " [ 1.8527e+00, -1.0491e+01],\n", - " ...,\n", - " [-1.1157e+01, 3.4423e+00],\n", - " [ 3.1193e+00, -4.4388e+00],\n", - " [-8.8242e+00, 8.0324e+00]],\n", - "\n", - " [[-6.5080e+01, 2.9543e+00],\n", - " [ 3.9992e+01, -1.3836e+01],\n", - " [-9.2803e+00, 1.0318e+01],\n", - " ...,\n", - " [ 4.2928e+00, 9.2397e+00],\n", - " [ 3.6642e+00, 9.4680e+00],\n", - " [ 4.8932e+00, -2.5199e+01]],\n", - "\n", - " [[ 4.7264e+01, -1.0721e+00],\n", - " [-6.0516e+00, -1.4589e+01],\n", - " [ 1.3127e+01, 1.4995e+00],\n", - " ...,\n", - " [ 1.7333e+01, -1.4380e+01],\n", - " [-3.6046e+00, -6.1019e+00],\n", - " [ 1.3321e+01, 2.3184e+01]]]])\n", - "True\n" - ] - } - ], - "source": [ - "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sr)\n", - "print(wav)\n", - "print(wav.shape)\n", - "print(wav.dtype)\n", - "print(np.allclose(wav, song))\n", - "\n", - "x = wav\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", - " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", - " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", - "\n", - "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "print(wav_spec.shape)\n", - "print(wav_spec)\n", - "print(np.allclose(wav_spec, spec))" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "running-technology", - "metadata": {}, - "outputs": [], - "source": [ - "import decimal\n", - "\n", - "import numpy\n", - "import math\n", - "import logging\n", - "def round_half_up(number):\n", - " return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n", - "\n", - "\n", - "def rolling_window(a, window, step=1):\n", - " # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n", - " shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n", - " strides = a.strides + (a.strides[-1],)\n", - " return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n", - "\n", - "\n", - "def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n", - " \"\"\"Frame a signal into overlapping frames.\n", - "\n", - " :param sig: the audio signal to frame.\n", - " :param frame_len: length of each frame measured in samples.\n", - " :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n", - " :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n", - " :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n", - " :returns: an array of frames. Size is NUMFRAMES by frame_len.\n", - " \"\"\"\n", - " slen = len(sig)\n", - " frame_len = int(round_half_up(frame_len))\n", - " frame_step = int(round_half_up(frame_step))\n", - " if slen <= frame_len:\n", - " numframes = 1\n", - " else:\n", - " numframes = 1 + (( slen - frame_len) // frame_step)\n", - "\n", - " # check kaldi/src/feat/feature-window.h\n", - " padsignal = sig[:(numframes-1)*frame_step+frame_len]\n", - " if wintype is 'povey':\n", - " win = numpy.empty(frame_len)\n", - " for i in range(frame_len):\n", - " win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n", - " else: # the hamming window\n", - " win = numpy.hamming(frame_len)\n", - " \n", - " if stride_trick:\n", - " frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n", - " else:\n", - " indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n", - " numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n", - " indices = numpy.array(indices, dtype=numpy.int32)\n", - " frames = padsignal[indices]\n", - " win = numpy.tile(win, (numframes, 1))\n", - " \n", - " frames = frames.astype(numpy.float32)\n", - " raw_frames = numpy.zeros(frames.shape)\n", - " for frm in range(frames.shape[0]):\n", - " raw_frames[frm,:] = frames[frm,:]\n", - " frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n", - " frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n", - " # raw_frames[frm,:] = frames[frm,:]\n", - " frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n", - "\n", - " return frames * win, raw_frames\n", - "\n", - "\n", - "def magspec(frames, NFFT):\n", - " \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", - "\n", - " :param frames: the array of frames. Each row is a frame.\n", - " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", - " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n", - " \"\"\"\n", - " if numpy.shape(frames)[1] > NFFT:\n", - " logging.warn(\n", - " 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n", - " numpy.shape(frames)[1], NFFT)\n", - " complex_spec = numpy.fft.rfft(frames, NFFT)\n", - " return numpy.absolute(complex_spec)\n", - "\n", - "\n", - "def powspec(frames, NFFT):\n", - " \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", - "\n", - " :param frames: the array of frames. Each row is a frame.\n", - " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", - " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n", - " \"\"\"\n", - " return numpy.square(magspec(frames, NFFT))\n", - "\n", - "\n", - "def do_dither(signal, dither_value=1.0):\n", - " signal += numpy.random.normal(size=signal.shape) * dither_value\n", - " return signal\n", - " \n", - "def do_remove_dc_offset(signal):\n", - " signal -= numpy.mean(signal)\n", - " return signal\n", - "\n", - "def do_preemphasis(signal, coeff=0.97):\n", - " \"\"\"perform preemphasis on the input signal.\n", - "\n", - " :param signal: The signal to filter.\n", - " :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n", - " :returns: the filtered signal.\n", - " \"\"\"\n", - " return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "ignored-retreat", - "metadata": {}, - "outputs": [], - "source": [ - "def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", - " wintype='hamming'):\n", - " highfreq= highfreq or samplerate/2\n", - " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", - " spec = magspec(frames, nfft) # nearly the same until this part\n", - " rspec = magspec(raw_frames, nfft)\n", - " return spec, rspec\n", - "\n", - "\n", - "\n", - "def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", - " wintype='hamming'):\n", - " highfreq= highfreq or samplerate/2\n", - " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", - " return raw_frames" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "federal-teacher", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "import scipy # used only in CFP\n", - "\n", - "import numpy as np\n", - "from time import time\n", - "\n", - "def pad_center(data, size, axis=-1, **kwargs):\n", - "\n", - " kwargs.setdefault('mode', 'constant')\n", - "\n", - " n = data.shape[axis]\n", - "\n", - " lpad = int((size - n) // 2)\n", - "\n", - " lengths = [(0, 0)] * data.ndim\n", - " lengths[axis] = (lpad, int(size - n - lpad))\n", - "\n", - " if lpad < 0:\n", - " raise ParameterError(('Target size ({:d}) must be '\n", - " 'at least input size ({:d})').format(size, n))\n", - "\n", - " return np.pad(data, lengths, **kwargs)\n", - "\n", - "\n", - "\n", - "sz_float = 4 # size of a float\n", - "epsilon = 10e-8 # fudge factor for normalization\n", - "\n", - "def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n", - " freq_scale='linear', window='hann', verbose=True):\n", - "\n", - " if freq_bins==None: freq_bins = n_fft//2+1\n", - " if win_length==None: win_length = n_fft\n", - "\n", - " s = np.arange(0, n_fft, 1.)\n", - " wsin = np.empty((freq_bins,1,n_fft))\n", - " wcos = np.empty((freq_bins,1,n_fft))\n", - " start_freq = fmin\n", - " end_freq = fmax\n", - " bins2freq = []\n", - " binslist = []\n", - "\n", - " # num_cycles = start_freq*d/44000.\n", - " # scaling_ind = np.log(end_freq/start_freq)/k\n", - "\n", - " # Choosing window shape\n", - "\n", - " #window_mask = get_window(window, int(win_length), fftbins=True)\n", - " window_mask = np.hamming(int(win_length))\n", - " window_mask = pad_center(window_mask, n_fft)\n", - "\n", - " if freq_scale == 'linear':\n", - " if verbose==True:\n", - " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", - " f\"get a valid freq range\")\n", - " \n", - " start_bin = start_freq*n_fft/sr\n", - " scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n", - "\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n", - " bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n", - " binslist.append((k*scaling_ind+start_bin))\n", - " wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", - "\n", - " elif freq_scale == 'log':\n", - " if verbose==True:\n", - " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", - " f\"get a valid freq range\")\n", - " start_bin = start_freq*n_fft/sr\n", - " scaling_ind = np.log(end_freq/start_freq)/freq_bins\n", - "\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n", - " bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n", - " binslist.append((np.exp(k*scaling_ind)*start_bin))\n", - " wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", - "\n", - " elif freq_scale == 'no':\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " bins2freq.append(k*sr/n_fft)\n", - " binslist.append(k)\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", - " else:\n", - " print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n", - " return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n", - "\n", - "\n", - "\n", - "def broadcast_dim(x):\n", - " \"\"\"\n", - " Auto broadcast input so that it can fits into a Conv1d\n", - " \"\"\"\n", - "\n", - " if x.dim() == 2:\n", - " x = x[:, None, :]\n", - " elif x.dim() == 1:\n", - " # If nn.DataParallel is used, this broadcast doesn't work\n", - " x = x[None, None, :]\n", - " elif x.dim() == 3:\n", - " pass\n", - " else:\n", - " raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n", - " return x\n", - "\n", - "\n", - "\n", - "### --------------------------- Spectrogram Classes ---------------------------###\n", - "class STFT(torch.nn.Module):\n", - "\n", - " def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n", - " freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n", - " fmin=50, fmax=6000, sr=22050, trainable=False,\n", - " output_format=\"Complex\", verbose=True):\n", - "\n", - " super().__init__()\n", - "\n", - " # Trying to make the default setting same as librosa\n", - " if win_length==None: win_length = n_fft\n", - " if hop_length==None: hop_length = int(win_length // 4)\n", - "\n", - " self.output_format = output_format\n", - " self.trainable = trainable\n", - " self.stride = hop_length\n", - " self.center = center\n", - " self.pad_mode = pad_mode\n", - " self.n_fft = n_fft\n", - " self.freq_bins = freq_bins\n", - " self.trainable = trainable\n", - " self.pad_amount = self.n_fft // 2\n", - " self.window = window\n", - " self.win_length = win_length\n", - " self.iSTFT = iSTFT\n", - " self.trainable = trainable\n", - " start = time()\n", - "\n", - "\n", - "\n", - " # Create filter windows for stft\n", - " kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n", - " win_length=win_length,\n", - " freq_bins=freq_bins,\n", - " window=window,\n", - " freq_scale=freq_scale,\n", - " fmin=fmin,\n", - " fmax=fmax,\n", - " sr=sr,\n", - " verbose=verbose)\n", - "\n", - "\n", - " kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n", - " kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n", - " \n", - " # In this way, the inverse kernel and the forward kernel do not share the same memory...\n", - " kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n", - " kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n", - " \n", - " if iSTFT:\n", - " self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n", - " self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n", - "\n", - " # Applying window functions to the Fourier kernels\n", - " if window:\n", - " window_mask = torch.tensor(window_mask)\n", - " wsin = kernel_sin * window_mask\n", - " wcos = kernel_cos * window_mask\n", - " else:\n", - " wsin = kernel_sin\n", - " wcos = kernel_cos\n", - " \n", - " if self.trainable==False:\n", - " self.register_buffer('wsin', wsin)\n", - " self.register_buffer('wcos', wcos) \n", - " \n", - " if self.trainable==True:\n", - " wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n", - " wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n", - " self.register_parameter('wsin', wsin)\n", - " self.register_parameter('wcos', wcos) \n", - " \n", - " # Prepare the shape of window mask so that it can be used later in inverse\n", - " # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n", - " \n", - " if verbose==True:\n", - " print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n", - " else:\n", - " pass\n", - "\n", - " def forward(self, x, output_format=None):\n", - " \"\"\"\n", - " Convert a batch of waveforms to spectrograms.\n", - " \n", - " Parameters\n", - " ----------\n", - " x : torch tensor\n", - " Input signal should be in either of the following shapes.\\n\n", - " 1. ``(len_audio)``\\n\n", - " 2. ``(num_audio, len_audio)``\\n\n", - " 3. ``(num_audio, 1, len_audio)``\n", - " It will be automatically broadcast to the right shape\n", - " \n", - " output_format : str\n", - " Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n", - " Default value is ``Complex``. \n", - " \n", - " \"\"\"\n", - " output_format = output_format or self.output_format\n", - " self.num_samples = x.shape[-1]\n", - " \n", - " x = broadcast_dim(x)\n", - " if self.center:\n", - " if self.pad_mode == 'constant':\n", - " padding = nn.ConstantPad1d(self.pad_amount, 0)\n", - "\n", - " elif self.pad_mode == 'reflect':\n", - " if self.num_samples < self.pad_amount:\n", - " raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n", - " padding = nn.ReflectionPad1d(self.pad_amount)\n", - "\n", - " x = padding(x)\n", - " spec_imag = conv1d(x, self.wsin, stride=self.stride)\n", - " spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n", - "\n", - " # remove redundant parts\n", - " spec_real = spec_real[:, :self.freq_bins, :]\n", - " spec_imag = spec_imag[:, :self.freq_bins, :]\n", - "\n", - " if output_format=='Magnitude':\n", - " spec = spec_real.pow(2) + spec_imag.pow(2)\n", - " if self.trainable==True:\n", - " return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n", - " else:\n", - " return torch.sqrt(spec)\n", - "\n", - " elif output_format=='Complex':\n", - " return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n", - "\n", - " elif output_format=='Phase':\n", - " return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n", - "\n", - " def inverse(self, X, onesided=True, length=None, refresh_win=True):\n", - " \"\"\"\n", - " This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n", - " which is to convert spectrograms back to waveforms. \n", - " It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n", - " please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n", - " \n", - " Parameters\n", - " ----------\n", - " onesided : bool\n", - " If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n", - " else use ``onesided=False``\n", - "\n", - " length : int\n", - " To make sure the inverse STFT has the same output length of the original waveform, please\n", - " set `length` as your intended waveform length. By default, ``length=None``,\n", - " which will remove ``n_fft//2`` samples from the start and the end of the output.\n", - " \n", - " refresh_win : bool\n", - " Recalculating the window sum square. If you have an input with fixed number of timesteps,\n", - " you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n", - " \n", - " \n", - " \"\"\"\n", - " if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n", - " raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n", - " \n", - " assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n", - " \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n", - " \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n", - " if onesided:\n", - " X = extend_fbins(X) # extend freq\n", - "\n", - " \n", - " X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n", - "\n", - " # broadcast dimensions to support 2D convolution\n", - " X_real_bc = X_real.unsqueeze(1)\n", - " X_imag_bc = X_imag.unsqueeze(1)\n", - " a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n", - " b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n", - " \n", - " # compute real and imag part. signal lies in the real part\n", - " real = a1 - b2\n", - " real = real.squeeze(-2)*self.window_mask\n", - "\n", - " # Normalize the amplitude with n_fft\n", - " real /= (self.n_fft)\n", - "\n", - " # Overlap and Add algorithm to connect all the frames\n", - " real = overlap_add(real, self.stride)\n", - " \n", - " # Prepare the window sumsqure for division\n", - " # Only need to create this window once to save time\n", - " # Unless the input spectrograms have different time steps\n", - " if hasattr(self, 'w_sum')==False or refresh_win==True:\n", - " self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n", - " self.nonzero_indices = (self.w_sum>1e-10) \n", - " else:\n", - " pass\n", - " real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n", - " # Remove padding\n", - " if length is None: \n", - " if self.center:\n", - " real = real[:, self.pad_amount:-self.pad_amount]\n", - "\n", - " else:\n", - " if self.center:\n", - " real = real[:, self.pad_amount:self.pad_amount + length] \n", - " else:\n", - " real = real[:, :length] \n", - " \n", - " return real\n", - " \n", - " def extra_repr(self) -> str:\n", - " return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n", - " self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n", - " ) " - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "id": "unusual-baker", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "(83792,)\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.0153 seconds\n", - "torch.Size([521, 257])\n", - "(522, 257)\n", - "[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n", - " 2.32206573e+01 1.90274487e+01]\n", - " [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n", - " 1.80534729e+02 1.84179596e+02]\n", - " [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n", - " 1.21321175e+02 1.08345497e+02]\n", - " ...\n", - " [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n", - " 1.25670158e+02 1.20691467e+02]\n", - " [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n", - " 9.01831894e+01 9.84055099e+01]\n", - " [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n", - " 8.94473553e+00 9.61348629e+00]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]\n", - " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", - " 6.36197377e+01 4.40000000e+01]]\n" - ] - } - ], - "source": [ - "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sr)\n", - "print(wav.shape)\n", - "\n", - "x = wav\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n", - " window='', freq_scale='linear', center=False, pad_mode='constant',\n", - " fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n", - "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "wav_spec = wav_spec[0].T\n", - "print(wav_spec.shape)\n", - "\n", - "\n", - "spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "print(spec.shape)\n", - "\n", - "print(wav_spec.numpy())\n", - "print(rspec)\n", - "# print(spec)\n", - "\n", - "# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n", - "# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - "# dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - "# wintype='hamming')\n", - "# print(rspec)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "white-istanbul", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 129, - "id": "modern-rescue", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n", - " 0.75 0.41317591 0.11697778 0. ]\n" - ] - }, - { - "data": { - "text/plain": [ - "array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n", - " 0.9045085, 0.6545085, 0.3454915, 0.0954915])" - ] - }, - "execution_count": 129, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(np.hanning(10))\n", - "from scipy.signal import get_window\n", - "get_window('hann', 10, fftbins=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "professional-journalism", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 153, - "id": "involved-motion", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(522, 400)\n", - "[[ 43. 75. 69. ... 46. 46. 45.]\n", - " [ 210. 215. 216. ... -86. -89. -91.]\n", - " [ 128. 128. 128. ... -154. -151. -151.]\n", - " ...\n", - " [ -60. -61. -61. ... 112. 109. 110.]\n", - " [ 20. 22. 24. ... 91. 87. 87.]\n", - " [ 111. 107. 108. ... -6. -4. -8.]]\n", - "torch.Size([1, 1, 83792])\n", - "torch.Size([400, 1, 512])\n", - "torch.Size([1, 400, 521])\n", - "conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n", - " [ 210., 215., 216., ..., -86., -89., -91.],\n", - " [ 128., 128., 128., ..., -154., -151., -151.],\n", - " ...,\n", - " [-143., -141., -142., ..., 96., 101., 101.],\n", - " [ -60., -61., -61., ..., 112., 109., 110.],\n", - " [ 20., 22., 24., ..., 91., 87., 87.]])\n", - "xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", - " 2.4064583e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", - " 9.5781029e+01 1.4200000e+02]\n", - " ...\n", - " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", - " 9.1511757e+01 1.1500000e+02]\n", - " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", - " 7.8405365e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", - " 5.1310158e+01 3.5000000e+01]]\n", - "torch.Size([521, 257])\n", - "yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", - " 9.15117270e+01 1.15000000e+02]\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]]\n", - "yy (522, 257)\n", - "[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", - " 2.4064583e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", - " 9.5781029e+01 1.4200000e+02]\n", - " ...\n", - " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", - " 9.1511757e+01 1.1500000e+02]\n", - " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", - " 7.8405365e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", - " 5.1310158e+01 3.5000000e+01]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", - " 9.15117270e+01 1.15000000e+02]\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]]\n", - "False\n" - ] - } - ], - "source": [ - "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "print(f.shape)\n", - "print(f)\n", - "\n", - "n_fft=512\n", - "freq_bins = n_fft//2+1\n", - "s = np.arange(0, n_fft, 1.)\n", - "wsin = np.empty((freq_bins,1,n_fft))\n", - "wcos = np.empty((freq_bins,1,n_fft))\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", - "\n", - "\n", - "wsin = np.empty((n_fft,1,n_fft))\n", - "wcos = np.empty((n_fft,1,n_fft))\n", - "for k in range(n_fft): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n", - " wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n", - " \n", - " \n", - "wsin = np.empty((400,1,n_fft))\n", - "wcos = np.empty((400,1,n_fft))\n", - "for k in range(400): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.eye(400, n_fft)[k]\n", - " wcos[k,0,:] = np.eye(400, n_fft)[k]\n", - " \n", - "\n", - " \n", - "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", - "x = x[None, None, :]\n", - "print(x.size())\n", - "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", - "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", - "print(kernel_sin.size())\n", - "\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "spec_imag = conv1d(x, kernel_sin, stride=160)\n", - "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", - "\n", - "print(spec_imag.size())\n", - "print(\"conv frame\", spec_imag[0].T)\n", - "# print(spec_imag[0].T[:, :400])\n", - "\n", - "# remove redundant parts\n", - "# spec_real = spec_real[:, :freq_bins, :]\n", - "# spec_imag = spec_imag[:, :freq_bins, :]\n", - "# spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "# spec = torch.sqrt(spec)\n", - "# print(spec)\n", - "\n", - "\n", - "\n", - "s = np.arange(0, 512, 1.)\n", - "# s = s[::-1]\n", - "wsin = np.empty((freq_bins, 400))\n", - "wcos = np.empty((freq_bins, 400))\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", - " wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", - "\n", - "spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n", - "spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n", - "\n", - "\n", - "# remove redundant parts\n", - "spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "spec = torch.sqrt(spec)\n", - "\n", - "print('xx', spec.numpy())\n", - "print(spec.size())\n", - "print('yy', rspec[:521, :])\n", - "print('yy', rspec.shape)\n", - "\n", - "\n", - "x = spec.numpy()\n", - "y = rspec[:-1, :]\n", - "print(x)\n", - "print(y)\n", - "print(np.allclose(x, y))" - ] - }, - { - "cell_type": "code", - "execution_count": 160, - "id": "mathematical-traffic", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([257, 1, 400])\n", - "tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n", - " 3.8693e+04, 3.1020e+04],\n", - " [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n", - " 1.3598e+04, 1.5920e+04],\n", - " [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n", - " 6.7707e+03, 4.3020e+03],\n", - " ...,\n", - " [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n", - " 6.1071e+01, 5.3685e+01],\n", - " [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n", - " 5.1310e+01, 6.3620e+01],\n", - " [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n", - " 3.5000e+01, 4.4000e+01]]])\n", - "[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n", - " 2.4064537e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n", - " 9.5781006e+01 1.4200000e+02]\n", - " ...\n", - " [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n", - " 7.8405350e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n", - " 5.1310150e+01 3.5000000e+01]\n", - " [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n", - " 6.3619797e+01 4.4000000e+01]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]\n", - " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", - " 6.36197377e+01 4.40000000e+01]]\n", - "False\n" - ] - } - ], - "source": [ - "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "\n", - "n_fft=512\n", - "freq_bins = n_fft//2+1\n", - "s = np.arange(0, n_fft, 1.)\n", - "wsin = np.empty((freq_bins,1,400))\n", - "wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", - "\n", - " \n", - "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", - "x = x[None, None, :] #[B, C, T]\n", - "\n", - "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", - "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", - "print(kernel_sin.size())\n", - "\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n", - "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", - "\n", - "# remove redundant parts\n", - "spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "spec = torch.sqrt(spec)\n", - "print(spec)\n", - "\n", - "x = spec[0].T.numpy()\n", - "y = rspec[:, :]\n", - "print(x)\n", - "print(y)\n", - "print(np.allclose(x, y))" - ] - }, - { - "cell_type": "code", - "execution_count": 162, - "id": "olive-nicaragua", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n", - " \"\"\"Entry point for launching an IPython kernel.\n" - ] - }, - { - "data": { - "text/plain": [ - "27241" - ] - }, - "execution_count": 162, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.argmax(np.abs(x -y) / np.abs(y))" - ] - }, - { - "cell_type": "code", - "execution_count": 165, - "id": "ultimate-assault", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.0" - ] - }, - "execution_count": 165, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y[np.unravel_index(27241, y.shape)]" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "institutional-stock", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.2412265e-10" - ] - }, - "execution_count": 166, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[np.unravel_index(27241, y.shape)]" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "id": "integrated-courage", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(y, x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "different-operation", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/compute_cmvn_loader_test.ipynb b/.notebook/compute_cmvn_loader_test.ipynb deleted file mode 100644 index 2b0a8b75f..000000000 --- a/.notebook/compute_cmvn_loader_test.ipynb +++ /dev/null @@ -1,793 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "purple-consequence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "defensive-mason", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "patient-convention", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(delta_delta=False, feat_dim=80, manifest_path='examples/aishell/s1/data/manifest.train.raw', num_samples=-1, num_workers=16, output_path='data/librispeech/mean_std.npz', sample_rate=16000, specgram_type='fbank', stride_ms=10.0, window_ms=25.0)\n" - ] - } - ], - "source": [ - "import argparse\n", - "import functools\n", - "\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "from deepspeech.utils.utility import add_arguments\n", - "from deepspeech.utils.utility import print_arguments\n", - "\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, -1, \"# of samples to for statistics.\")\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc, fbank.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n", - "add_arg('delta_delta', bool,\n", - " False,\n", - " \"Audio feature with delta delta.\")\n", - "add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n", - "add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n", - "add_arg('sample_rate', int, 16000, \"target sample rate.\")\n", - "add_arg('manifest_path', str,\n", - " 'examples/aishell/s1/data/manifest.train.raw',\n", - " \"Filepath of manifest to compute normalizer's mean and stddev.\")\n", - "add_arg('num_workers',\n", - " default=16,\n", - " type=int,\n", - " help='num of subprocess workers for processing')\n", - "add_arg('output_path', str,\n", - " 'data/librispeech/mean_std.npz',\n", - " \"Filepath of write mean and stddev to (.npz).\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "enormous-currency", - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "\n", - "import numpy as np\n", - "import paddle\n", - "from paddle.io import DataLoader\n", - "from paddle.io import Dataset\n", - "\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "class CollateFunc(object):\n", - " ''' Collate function for AudioDataset\n", - " '''\n", - " def __init__(self):\n", - " pass\n", - " \n", - " def __call__(self, batch):\n", - " mean_stat = None\n", - " var_stat = None\n", - " number = 0\n", - " for feat in batch:\n", - " sums = np.sum(feat, axis=1)\n", - " if mean_stat is None:\n", - " mean_stat = sums\n", - " else:\n", - " mean_stat += sums\n", - "\n", - " square_sums = np.sum(np.square(feat), axis=1)\n", - " if var_stat is None:\n", - " var_stat = square_sums\n", - " else:\n", - " var_stat += square_sums\n", - "\n", - " number += feat.shape[1]\n", - " #return paddle.to_tensor(number), paddle.to_tensor(mean_stat), paddle.to_tensor(var_stat)\n", - " return number, mean_stat, var_stat\n", - "\n", - "\n", - "class AudioDataset(Dataset):\n", - " def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):\n", - " self.feature_func = feature_func\n", - " self._rng = rng\n", - " manifest = read_manifest(manifest_path)\n", - " if num_samples == -1:\n", - " sampled_manifest = manifest\n", - " else:\n", - " sampled_manifest = self._rng.sample(manifest, num_samples)\n", - " self.items = sampled_manifest\n", - "\n", - " def __len__(self):\n", - " return len(self.items)\n", - "\n", - " def __getitem__(self, idx):\n", - " key = self.items[idx]['feat']\n", - " audioseg = AudioSegment.from_file(key)\n", - " feat = self.feature_func(audioseg) #(D, T)\n", - " return feat" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "armed-semester", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "process 1000 wavs,450739 frames\n", - "process 2000 wavs,887447 frames\n", - "process 3000 wavs,1354148 frames\n", - "process 4000 wavs,1816494 frames\n", - "process 5000 wavs,2359211 frames\n", - "process 6000 wavs,2828455 frames\n", - "process 7000 wavs,3276186 frames\n", - "process 8000 wavs,3692234 frames\n", - "process 9000 wavs,4139360 frames\n", - "process 10000 wavs,4591528 frames\n", - "process 11000 wavs,5020114 frames\n", - "process 12000 wavs,5459523 frames\n", - "process 13000 wavs,5899534 frames\n", - "process 14000 wavs,6323242 frames\n", - "process 15000 wavs,6736597 frames\n", - "process 16000 wavs,7207686 frames\n", - "process 17000 wavs,7637800 frames\n", - "process 18000 wavs,8093004 frames\n", - "process 19000 wavs,8529518 frames\n", - "process 20000 wavs,8906022 frames\n", - "process 21000 wavs,9352652 frames\n", - "process 22000 wavs,9807495 frames\n", - "process 23000 wavs,10247938 frames\n", - "process 24000 wavs,10700011 frames\n", - "process 25000 wavs,11126134 frames\n", - "process 26000 wavs,11558061 frames\n", - "process 27000 wavs,12010359 frames\n", - "process 28000 wavs,12470938 frames\n", - "process 29000 wavs,12916013 frames\n", - "process 30000 wavs,13345816 frames\n", - "process 31000 wavs,13752365 frames\n", - "process 32000 wavs,14174801 frames\n", - "process 33000 wavs,14642170 frames\n", - "process 34000 wavs,15053557 frames\n", - "process 35000 wavs,15531890 frames\n", - "process 36000 wavs,16022711 frames\n", - "process 37000 wavs,16437688 frames\n", - "process 38000 wavs,16859517 frames\n", - "process 39000 wavs,17307676 frames\n", - "process 40000 wavs,17796629 frames\n", - "process 41000 wavs,18264151 frames\n", - "process 42000 wavs,18711898 frames\n", - "process 43000 wavs,19159890 frames\n", - "process 44000 wavs,19576435 frames\n", - "process 45000 wavs,19992793 frames\n", - "process 46000 wavs,20464449 frames\n", - "process 47000 wavs,20886021 frames\n", - "process 48000 wavs,21317318 frames\n", - "process 49000 wavs,21738034 frames\n", - "process 50000 wavs,22171890 frames\n", - "process 51000 wavs,22622238 frames\n", - "process 52000 wavs,23100734 frames\n", - "process 53000 wavs,23526901 frames\n", - "process 54000 wavs,23969746 frames\n", - "process 55000 wavs,24418691 frames\n", - "process 56000 wavs,24862546 frames\n", - "process 57000 wavs,25336448 frames\n", - "process 58000 wavs,25778435 frames\n", - "process 59000 wavs,26216199 frames\n", - "process 60000 wavs,26694692 frames\n", - "process 61000 wavs,27148978 frames\n", - "process 62000 wavs,27617088 frames\n", - "process 63000 wavs,28064946 frames\n", - "process 64000 wavs,28519843 frames\n", - "process 65000 wavs,28989722 frames\n", - "process 66000 wavs,29470156 frames\n", - "process 67000 wavs,29952931 frames\n", - "process 68000 wavs,30360555 frames\n", - "process 69000 wavs,30797929 frames\n", - "process 70000 wavs,31218227 frames\n", - "process 71000 wavs,31663934 frames\n", - "process 72000 wavs,32107468 frames\n", - "process 73000 wavs,32541943 frames\n", - "process 74000 wavs,33010702 frames\n", - "process 75000 wavs,33448082 frames\n", - "process 76000 wavs,33886812 frames\n", - "process 77000 wavs,34338108 frames\n", - "process 78000 wavs,34761495 frames\n", - "process 79000 wavs,35199730 frames\n", - "process 80000 wavs,35669630 frames\n", - "process 81000 wavs,36122402 frames\n", - "process 82000 wavs,36604561 frames\n", - "process 83000 wavs,37085552 frames\n", - "process 84000 wavs,37517500 frames\n", - "process 85000 wavs,37987196 frames\n", - "process 86000 wavs,38415721 frames\n", - "process 87000 wavs,38889467 frames\n", - "process 88000 wavs,39337809 frames\n", - "process 89000 wavs,39792342 frames\n", - "process 90000 wavs,40287946 frames\n", - "process 91000 wavs,40719461 frames\n", - "process 92000 wavs,41178919 frames\n", - "process 93000 wavs,41659635 frames\n", - "process 94000 wavs,42132985 frames\n", - "process 95000 wavs,42584564 frames\n", - "process 96000 wavs,43018598 frames\n", - "process 97000 wavs,43480662 frames\n", - "process 98000 wavs,43973670 frames\n", - "process 99000 wavs,44448190 frames\n", - "process 100000 wavs,44935034 frames\n", - "process 101000 wavs,45379812 frames\n", - "process 102000 wavs,45821207 frames\n", - "process 103000 wavs,46258420 frames\n", - "process 104000 wavs,46743733 frames\n", - "process 105000 wavs,47206922 frames\n", - "process 106000 wavs,47683041 frames\n", - "process 107000 wavs,48122809 frames\n", - "process 108000 wavs,48594623 frames\n", - "process 109000 wavs,49086358 frames\n", - "process 110000 wavs,49525568 frames\n", - "process 111000 wavs,49985820 frames\n", - "process 112000 wavs,50428262 frames\n", - "process 113000 wavs,50897957 frames\n", - "process 114000 wavs,51344589 frames\n", - "process 115000 wavs,51774621 frames\n", - "process 116000 wavs,52243372 frames\n", - "process 117000 wavs,52726025 frames\n", - "process 118000 wavs,53170026 frames\n", - "process 119000 wavs,53614141 frames\n", - "process 120000 wavs,54071271 frames\n" - ] - } - ], - "source": [ - "\n", - "augmentation_pipeline = AugmentationPipeline('{}')\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20)\n", - "\n", - "def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - "\n", - "collate_func = CollateFunc()\n", - "\n", - "dataset = AudioDataset(\n", - " args.manifest_path,\n", - " augment_and_featurize, \n", - " args.num_samples)\n", - "\n", - "batch_size = 20\n", - "data_loader = DataLoader(\n", - " dataset,\n", - " batch_size=batch_size,\n", - " shuffle=False,\n", - " num_workers=args.num_workers,\n", - " collate_fn=collate_func)\n", - "\n", - "with paddle.no_grad():\n", - " all_mean_stat = None\n", - " all_var_stat = None\n", - " all_number = 0\n", - " wav_number = 0\n", - " for i, batch in enumerate(data_loader()):\n", - " #for batch in data_loader():\n", - " number, mean_stat, var_stat = batch\n", - " if i == 0:\n", - " all_mean_stat = mean_stat\n", - " all_var_stat = var_stat\n", - " else:\n", - " all_mean_stat += mean_stat\n", - " all_var_stat += var_stat\n", - " all_number += number\n", - " wav_number += batch_size\n", - "\n", - " if wav_number % 1000 == 0:\n", - " print('process {} wavs,{} frames'.format(wav_number,\n", - " all_number))\n", - "\n", - "cmvn_info = {\n", - " 'mean_stat': list(all_mean_stat.tolist()),\n", - " 'var_stat': list(all_var_stat.tolist()),\n", - " 'frame_num': all_number\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "danish-executive", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'mean_stat': [-813852467.7953382, -769025957.9140725, -809499593.411409, -774700574.014532, -750961217.5896736, -760564397.2864963, -805662399.3771614, -843490965.4231446, -850242081.9416809, -857678651.504435, -879067453.9826999, -908602072.3856701, -936850957.7187386, -957242686.489041, -968425442.0916103, -972687545.5953809, -980383731.7683417, -991533337.6343704, -1001966818.1164789, -1010334169.7486078, -1016855066.9099333, -1022176245.7021623, -1025700476.4788507, -1030678878.3195274, -1037075963.124199, -1042705719.0195516, -1047422212.6492896, -1049003537.271861, -1050314833.7453628, -1050772191.0204058, -1050010034.9948177, -1050436065.1336465, -1053327181.7978873, -1058710548.2036785, -1065950852.4966162, -1071709705.0060445, -1077682778.259181, -1083371045.272074, -1089708906.2657735, -1096312217.7865202, -1101089858.8364556, -1104965332.4332569, -1107791702.5223634, -1109431075.2374773, -1110066333.0280604, -1110382732.0722318, -1110480306.3793216, -1110203297.7110727, -1109972534.3583376, -1109378081.8792782, -1108212059.413654, -1107235713.2041805, -1106973581.9280007, -1107352339.7860134, -1108730029.862537, -1110425202.83704, -1113220669.4552443, -1115887535.4870913, -1118105356.3628063, -1120001376.8503075, -1121135822.320366, -1122265971.8751016, -1123990217.401155, -1125786729.6230593, -1127784957.2745507, -1129180108.9033566, -1132000461.6688302, -1134675829.8190608, -1137652487.5164194, -1141755948.0463965, -1145340901.5468378, -1148637682.593287, -1151755522.470022, -1154981643.2268832, -1157417488.840151, -1161240429.0989249, -1165411128.671642, -1170521097.1034513, -1176307165.5109766, -1183456865.0039694, -1190535938.6591117, -1197946309.0472982, -1203596565.037139, -1207563038.1241052, -1209707561.5829782, -1211407066.2452552, -1211884576.9201162, -1212778872.005509, -1214041413.8080075, -1215367953.1745043, -1216850831.482193, -1217678325.5351057, -1218854289.54188, -1219325064.8610544, -1219080344.7580786, -1218541313.657531, -1217889833.2067819, -1216552930.1654336, -1216423777.4113154, -1216575252.225508, -1217075384.9826024, -1217391577.901724, -1217838974.57273, -1218131805.6054134, -1218294889.7465532, -1218566666.1755593, -1218790537.5519717, -1218748668.9956846, -1218603191.4941735, -1218004566.4348054, -1217312410.127734, -1217207493.9522285, -1217284002.3834674, -1217644312.51745, -1218039821.6444128, -1218721811.6269798, -1219121088.9265897, -1219014460.8090584, -1218530127.6776083, -1217952335.451711, -1217316073.8666434, -1217035380.1151958, -1216636431.2964456, -1216257015.2945514, -1215658496.1208403, -1215097272.0976632, -1214669859.2064147, -1214593853.4809475, -1214599475.7838447, -1214575440.823035, -1214158828.8008435, -1213482920.2673717, -1212476577.5897374, -1211251374.2198513, -1210284855.590475, -1209302456.065669, -1209106252.6625297, -1209373211.5146718, -1209689421.7984035, -1210021342.495856, -1210650609.3592312, -1211428521.3900626, -1212616111.4257205, -1213820075.2948189, -1215320588.7144456, -1217175082.2739282, -1219703351.4585004, -1222007827.120464, -1224637375.5900724, -1228367798.912171, -1234853879.862459, -1247222219.867692, -1268562808.1616178, -1302034822.9569275, -1347823631.0776038, -1402753916.9445229, -1458826717.3262982, -1505843092.0970414, -1534278782.249077, -1543955545.8994718, -1600409154.893352], 'var_stat': [12665413908.91729, 11145088801.244318, 12567119446.035736, 11758392758.06822, 11200687982.736668, 11551903443.711124, 12880777868.435602, 14084854368.236998, 14394011058.866192, 14678818621.277662, 15346278722.626339, 16268053979.757076, 17191705347.854794, 17877540386.548733, 18251857849.077663, 18392628178.710472, 18645534548.4045, 19018598212.22902, 19366711357.782673, 19655730286.72857, 19890681996.786858, 20094163350.461906, 20227774955.225887, 20423525628.66887, 20669928826.76939, 20882313568.247944, 21062392676.270527, 21126648821.879055, 21185210734.751118, 21209014745.520447, 21182293842.91236, 21197433134.875977, 21302147790.662144, 21504666657.651955, 21781818550.89697, 21996170165.145462, 22217169779.096275, 22431161762.176693, 22672708668.38104, 22922683961.072956, 23101137011.201683, 23249680793.556847, 23358894817.24979, 23422895267.919228, 23449479198.303394, 23464433357.671055, 23469197140.124596, 23459013479.866177, 23447935341.542686, 23422585038.052387, 23375601301.949135, 23338397991.497776, 23329682884.21905, 23348002892.39853, 23406274659.89975, 23478242518.92228, 23592891371.876236, 23703885161.772205, 23797158601.65954, 23875230355.66992, 23918333664.3946, 23968582109.371258, 24040547318.081936, 24112364295.110058, 24189973697.612144, 24242165205.640236, 24364255205.82311, 24472408850.760197, 24590211203.05312, 24763026764.005527, 24909192634.69144, 25043438176.23281, 25167141466.500504, 25297108031.48665, 25395377064.0999, 25550930772.86505, 25721404827.10336, 25931101211.156487, 26168988710.098465, 26465528802.762875, 26760033029.443783, 27075408488.605213, 27316626931.655052, 27487275073.52796, 27579518448.2332, 27652308513.875782, 27673412508.45838, 27711509210.702576, 27767312240.641487, 27827464683.295334, 27894794590.957966, 27935988489.16511, 27992337099.891083, 28019655483.58796, 28014286886.252903, 27996189233.857716, 27973078840.875465, 27920045013.68706, 27917103211.22359, 27927566165.64652, 27953525818.61368, 27973386070.140022, 27999317832.502476, 28019494120.641834, 28033010746.452637, 28051086123.896503, 28066195174.191753, 28068570977.318798, 28064890246.85437, 28042424375.860577, 28015849655.869568, 28014812222.566605, 28021039053.959835, 28039270607.169422, 28058271295.10199, 28088976520.10178, 28107824988.74732, 28105633030.784756, 28087681357.818607, 28065484299.963837, 28039555887.004284, 28028214431.52875, 28011714871.929447, 27995603790.480755, 27970125897.561134, 27946436130.511288, 27929044772.5522, 27926612443.390316, 27926256324.387302, 27924771848.71099, 27905526922.390133, 27876268519.168198, 27832532606.552593, 27779497699.976765, 27737034351.907337, 27692129825.179924, 27684252911.371475, 27698882622.878677, 27712387157.27985, 27726474638.933037, 27752647691.051613, 27786197932.382797, 27836378752.662235, 27887415700.334576, 27949784230.702114, 28028117657.84245, 28136313097.200474, 28234098926.207996, 28345845477.25874, 28507222800.146496, 28793832339.90449, 29350765483.070816, 30328262350.231213, 31894930713.76519, 34093669067.422382, 36801959396.22739, 39638995447.49344, 42088579425.44825, 43616108982.85117, 44152063315.31461, 47464832889.5967], 'frame_num': 54129649}\n" - ] - } - ], - "source": [ - "print(cmvn_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "accurate-terminal", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "dominant-abuse", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "process 1000 wavs,450240 frames\n", - " \n", - "process 2000 wavs,886411 frames\n", - " \n", - "process 3000 wavs,1352580 frames\n", - " \n", - "process 4000 wavs,1814397 frames\n", - " \n", - "process 5000 wavs,2356587 frames\n", - " \n", - "process 6000 wavs,2825310 frames\n", - " \n", - "process 7000 wavs,3272506 frames\n", - " \n", - "process 8000 wavs,3688045 frames\n", - " \n", - "process 9000 wavs,4134669 frames\n", - " \n", - "process 10000 wavs,4586357 frames\n", - " \n", - "process 11000 wavs,5014429 frames\n", - " \n", - "process 12000 wavs,5453334 frames\n", - " \n", - "process 13000 wavs,5892888 frames\n", - " \n", - "process 14000 wavs,6316059 frames\n", - " \n", - "process 15000 wavs,6728870 frames\n", - " \n", - "process 16000 wavs,7199442 frames\n", - " \n", - "process 17000 wavs,7629055 frames\n", - " \n", - "process 18000 wavs,8083729 frames\n", - " \n", - "process 19000 wavs,8519732 frames\n", - " \n", - "process 20000 wavs,8895694 frames\n", - " \n", - "process 21000 wavs,9341778 frames\n", - " \n", - "process 22000 wavs,9796126 frames\n", - " \n", - "process 23000 wavs,10236057 frames\n", - " \n", - "process 24000 wavs,10687461 frames\n", - " \n", - "process 25000 wavs,11113082 frames\n", - " \n", - "process 26000 wavs,11544482 frames\n", - " \n", - "process 27000 wavs,11996273 frames\n", - " \n", - "process 28000 wavs,12456350 frames\n", - " \n", - "process 29000 wavs,12900895 frames\n", - " \n", - "process 30000 wavs,13330353 frames\n", - " \n", - "process 31000 wavs,13736568 frames\n", - " \n", - "process 32000 wavs,14158472 frames\n", - " \n", - "process 33000 wavs,14625316 frames\n", - " \n", - "process 34000 wavs,15036206 frames\n", - " \n", - "process 35000 wavs,15514001 frames\n", - " \n", - "process 36000 wavs,16004323 frames\n", - " \n", - "process 37000 wavs,16418799 frames\n", - " \n", - "process 38000 wavs,16840100 frames\n", - " \n", - "process 39000 wavs,17287752 frames\n", - " \n", - "process 40000 wavs,17776206 frames\n", - " \n", - "process 41000 wavs,18243209 frames\n", - " \n", - "process 42000 wavs,18690449 frames\n", - " \n", - "process 43000 wavs,19137940 frames\n", - " \n", - "process 44000 wavs,19553966 frames\n", - " \n", - "process 45000 wavs,19969813 frames\n", - " \n", - "process 46000 wavs,20440963 frames\n", - " \n", - "process 47000 wavs,20862022 frames\n", - " \n", - "process 48000 wavs,21292801 frames\n", - " \n", - "process 49000 wavs,21713004 frames\n", - " \n", - "process 50000 wavs,22146346 frames\n", - " \n", - "process 51000 wavs,22596172 frames\n", - " \n", - "process 52000 wavs,23074160 frames\n", - " \n", - "process 53000 wavs,23499823 frames\n", - " \n", - "process 54000 wavs,23942151 frames\n", - " \n", - "process 55000 wavs,24390566 frames\n", - " \n", - "process 56000 wavs,24833905 frames\n", - " \n", - "process 57000 wavs,25307270 frames\n", - " \n", - "process 58000 wavs,25748720 frames\n", - " \n", - "process 59000 wavs,26185964 frames\n", - " \n", - "process 60000 wavs,26663953 frames\n", - " \n", - "process 61000 wavs,27117720 frames\n", - " \n", - "process 62000 wavs,27585349 frames\n", - " \n", - "process 63000 wavs,28032693 frames\n", - " \n", - "process 64000 wavs,28487074 frames\n", - " \n", - "process 65000 wavs,28956462 frames\n", - " \n", - "process 66000 wavs,29436358 frames\n", - " \n", - "process 67000 wavs,29918569 frames\n", - " \n", - "process 68000 wavs,30325682 frames\n", - " \n", - "process 69000 wavs,30762528 frames\n", - " \n", - "process 70000 wavs,31182319 frames\n", - " \n", - "process 71000 wavs,31627526 frames\n", - " \n", - "process 72000 wavs,32070556 frames\n", - " \n", - "process 73000 wavs,32504534 frames\n", - " \n", - "process 74000 wavs,32972775 frames\n", - " \n", - "process 75000 wavs,33409637 frames\n", - " \n", - "process 76000 wavs,33847861 frames\n", - " \n", - "process 77000 wavs,34298647 frames\n", - " \n", - "process 78000 wavs,34721536 frames\n", - " \n", - "process 79000 wavs,35159236 frames\n", - " \n", - "process 80000 wavs,35628628 frames\n", - " \n", - "process 81000 wavs,36080909 frames\n", - " \n", - "process 82000 wavs,36562496 frames\n", - " \n", - "process 83000 wavs,37042976 frames\n", - " \n", - "process 84000 wavs,37474403 frames\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "process 85000 wavs,37943596 frames\n", - " \n", - "process 86000 wavs,38371620 frames\n", - " \n", - "process 87000 wavs,38844874 frames\n", - " \n", - "process 88000 wavs,39292686 frames\n", - " \n", - "process 89000 wavs,39746715 frames\n", - " \n", - "process 90000 wavs,40241800 frames\n", - " \n", - "process 91000 wavs,40672817 frames\n", - " \n", - "process 92000 wavs,41131773 frames\n", - " \n", - "process 93000 wavs,41612001 frames\n", - " \n", - "process 94000 wavs,42084822 frames\n", - " \n", - "process 95000 wavs,42535878 frames\n", - " \n", - "process 96000 wavs,42969365 frames\n", - " \n", - "process 97000 wavs,43430890 frames\n", - " \n", - "process 98000 wavs,43923378 frames\n", - " \n", - "process 99000 wavs,44397370 frames\n", - " \n", - "process 100000 wavs,44883695 frames\n", - " \n", - "process 101000 wavs,45327968 frames\n", - " \n", - "process 102000 wavs,45768860 frames\n", - " \n", - "process 103000 wavs,46205602 frames\n", - " \n", - "process 104000 wavs,46690407 frames\n", - " \n", - "process 105000 wavs,47153089 frames\n", - " \n", - "process 106000 wavs,47628699 frames\n", - " \n", - "process 107000 wavs,48067945 frames\n", - " \n", - "process 108000 wavs,48539256 frames\n", - " \n", - "process 109000 wavs,49030485 frames\n", - " \n", - "process 110000 wavs,49469189 frames\n", - " \n", - "process 111000 wavs,49928968 frames\n", - " \n", - "process 112000 wavs,50370921 frames\n", - " \n", - "process 113000 wavs,50840090 frames\n", - " \n", - "process 114000 wavs,51286249 frames\n", - " \n", - "process 115000 wavs,51715786 frames\n", - " \n", - "process 116000 wavs,52184017 frames\n", - " \n", - "process 117000 wavs,52666156 frames\n", - " \n", - "process 118000 wavs,53109645 frames\n", - " \n", - "process 119000 wavs,53553253 frames\n", - " \n", - "process 120000 wavs,54009877 frames\n", - "{'mean_stat': [700612678.1184504, 704246512.9321843, 720430663.1822729, 754033269.0474415, 798737761.616614, 829467218.4204571, 851246702.9426627, 862261185.2661449, 859339943.6923889, 846303730.8696194, 832995109.605447, 823196536.6029147, 832626008.2569772, 845571326.1936859, 848801373.0562981, 846503549.328017, 836774344.5500796, 823481091.0445303, 820728368.2518216, 804571348.4957463, 795306095.0083207, 811729024.2415155, 805734803.5703195, 813076782.1959459, 806620199.406499, 809655573.8886961, 804371708.9347517, 809272248.6085774, 810322689.7490631, 814294131.1973915, 816262716.0476038, 816213124.2411841, 817158473.4380915, 821414211.5629157, 827408091.5728914, 834353896.0519086, 840094990.3467333, 842613218.6554606, 842070761.1727513, 834970952.5260613, 837020570.8200948, 829592602.7833654, 830116543.8893851, 829482316.3881509, 833397219.4597517, 839251633.3120549, 845475010.4718693, 852378426.7183967, 859563981.8633184, 866063840.5523493, 867790921.9978689, 868215100.5962687, 869683066.032885, 872467375.6674014, 873097681.1780069, 873025823.0543871, 869897292.7201596, 866386426.3869117, 863166726.7256871, 854653071.2244718, 842402803.9000899, 830838253.4144138, 830143002.3536818, 831492285.0310817, 833304371.8781006, 838896092.8621838, 843866088.9578133, 847316792.1429776, 851038022.3643295, 855931698.0149751, 859320543.9795249, 863031001.3470656, 868325062.1832993, 873626971.0115026, 878726636.924209, 884861725.972504, 886920281.5192285, 883056006.5094173, 863719240.7255149, 773378975.9476194], 'var_stat': [9237018652.657722, 9417257721.82426, 10105084297.159702, 11071318522.587782, 12422783727.426847, 13400306419.784964, 14148498843.406874, 14576436982.89939, 14529009036.494726, 14105645932.596651, 13682988821.478252, 13413013425.088106, 13764134927.293928, 14233704806.737064, 14361631309.367067, 14281358385.45644, 13939662689.213865, 13496884231.929493, 13382566162.783987, 12871350930.6626, 12576198160.876635, 13051463889.56708, 12859205935.513906, 13053861416.098743, 12830323588.550724, 12886405923.897238, 12708529922.84171, 12847306110.231739, 12880398489.53404, 13002566299.565536, 13066708060.463543, 13064231286.858614, 13088983337.353497, 13221393824.891022, 13412425607.755072, 13631485149.777075, 13807797519.156103, 13877277485.033077, 13848613909.96762, 13609176326.2529, 13649815250.130072, 13397698404.696907, 13388964704.359968, 13354326914.968012, 13469861474.898457, 13652539440.283333, 13846837321.329163, 14062143714.601675, 14292571198.61228, 14504626563.299246, 14563864749.132776, 14579720287.991764, 14626700787.353922, 14716185568.128899, 14728532777.28015, 14719101187.113443, 14607945896.239174, 14478517828.531614, 14355110561.681187, 14057430280.249746, 13634284490.879377, 13248236002.494394, 13217602306.335958, 13257856701.946049, 13323688441.072674, 13515395318.023148, 13685827169.67645, 13811622609.426846, 13947347160.615082, 14115883822.884943, 14231204526.433033, 14356066668.651815, 14533604268.238445, 14708971788.69237, 14875667326.732443, 15079098318.79331, 15144888989.667963, 15002658970.504765, 14349232841.34513, 11544480117.013124], 'frame_num': 54068199}\n" - ] - } - ], - "source": [ - "import random\n", - "\n", - "import numpy as np\n", - "import paddle\n", - "from paddle.io import DataLoader\n", - "from paddle.io import Dataset\n", - "\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "# https://github.com/PaddlePaddle/Paddle/pull/31481\n", - "class CollateFunc(object):\n", - " ''' Collate function for AudioDataset\n", - " '''\n", - " def __init__(self, feature_func):\n", - " self.feature_func = feature_func\n", - " \n", - " def __call__(self, batch):\n", - " mean_stat = None\n", - " var_stat = None\n", - " number = 0\n", - " for item in batch:\n", - " audioseg = AudioSegment.from_file(item['feat'])\n", - " feat = self.feature_func(audioseg) #(D, T)\n", - "\n", - " sums = np.sum(feat, axis=1)\n", - " if mean_stat is None:\n", - " mean_stat = sums\n", - " else:\n", - " mean_stat += sums\n", - "\n", - " square_sums = np.sum(np.square(feat), axis=1)\n", - " if var_stat is None:\n", - " var_stat = square_sums\n", - " else:\n", - " var_stat += square_sums\n", - "\n", - " number += feat.shape[1]\n", - " return number, mean_stat, var_stat\n", - "\n", - "\n", - "class AudioDataset(Dataset):\n", - " def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):\n", - " self._rng = rng if rng else np.random.RandomState(random_seed)\n", - " manifest = read_manifest(manifest_path)\n", - " if num_samples == -1:\n", - " sampled_manifest = manifest\n", - " else:\n", - " sampled_manifest = self._rng.choice(manifest, num_samples, replace=False)\n", - " self.items = sampled_manifest\n", - "\n", - " def __len__(self):\n", - " return len(self.items)\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.items[idx]\n", - " \n", - " \n", - "augmentation_pipeline = AugmentationPipeline('{}')\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20)\n", - "\n", - "def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - "\n", - "collate_func = CollateFunc(augment_and_featurize)\n", - "\n", - "dataset = AudioDataset(\n", - " args.manifest_path,\n", - " args.num_samples)\n", - "\n", - "batch_size = 20\n", - "data_loader = DataLoader(\n", - " dataset,\n", - " batch_size=batch_size,\n", - " shuffle=False,\n", - " num_workers=args.num_workers,\n", - " collate_fn=collate_func)\n", - "\n", - "with paddle.no_grad():\n", - " all_mean_stat = None\n", - " all_var_stat = None\n", - " all_number = 0\n", - " wav_number = 0\n", - " for i, batch in enumerate(data_loader):\n", - " number, mean_stat, var_stat = batch\n", - " if i == 0:\n", - " all_mean_stat = mean_stat\n", - " all_var_stat = var_stat\n", - " else:\n", - " all_mean_stat += mean_stat\n", - " all_var_stat += var_stat\n", - " all_number += number\n", - " wav_number += batch_size\n", - "\n", - " if wav_number % 1000 == 0:\n", - " print('process {} wavs,{} frames'.format(wav_number,\n", - " all_number))\n", - "\n", - "cmvn_info = {\n", - " 'mean_stat': list(all_mean_stat.tolist()),\n", - " 'var_stat': list(all_var_stat.tolist()),\n", - " 'frame_num': all_number\n", - "}\n", - "print(cmvn_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unlike-search", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/dataloader.ipynb b/.notebook/dataloader.ipynb deleted file mode 100644 index 3de8f64a9..000000000 --- a/.notebook/dataloader.ipynb +++ /dev/null @@ -1,389 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "emerging-meter", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "import math\n", - "import random\n", - "import tarfile\n", - "import logging\n", - "import numpy as np\n", - "from collections import namedtuple\n", - "from functools import partial\n", - "\n", - "import paddle\n", - "from paddle.io import Dataset\n", - "from paddle.io import DataLoader\n", - "from paddle.io import BatchSampler\n", - "from paddle.io import DistributedBatchSampler\n", - "from paddle import distributed as dist\n", - "\n", - "from data_utils.utility import read_manifest\n", - "from data_utils.augmentor.augmentation import AugmentationPipeline\n", - "from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n", - "from data_utils.speech import SpeechSegment\n", - "from data_utils.normalizer import FeatureNormalizer\n", - "\n", - "\n", - "from data_utils.dataset import (\n", - " DeepSpeech2Dataset,\n", - " DeepSpeech2DistributedBatchSampler,\n", - " DeepSpeech2BatchSampler,\n", - " SpeechCollator,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "excessive-american", - "metadata": {}, - "outputs": [], - "source": [ - "def create_dataloader(manifest_path,\t\n", - " vocab_filepath,\t\n", - " mean_std_filepath,\t\n", - " augmentation_config='{}',\t\n", - " max_duration=float('inf'),\t\n", - " min_duration=0.0,\t\n", - " stride_ms=10.0,\t\n", - " window_ms=20.0,\t\n", - " max_freq=None,\t\n", - " specgram_type='linear',\t\n", - " use_dB_normalization=True,\t\n", - " random_seed=0,\t\n", - " keep_transcription_text=False,\t\n", - " is_training=False,\t\n", - " batch_size=1,\t\n", - " num_workers=0,\t\n", - " sortagrad=False,\t\n", - " shuffle_method=None,\t\n", - " dist=False):\t\n", - "\n", - " dataset = DeepSpeech2Dataset(\t\n", - " manifest_path,\t\n", - " vocab_filepath,\t\n", - " mean_std_filepath,\t\n", - " augmentation_config=augmentation_config,\t\n", - " max_duration=max_duration,\t\n", - " min_duration=min_duration,\t\n", - " stride_ms=stride_ms,\t\n", - " window_ms=window_ms,\t\n", - " max_freq=max_freq,\t\n", - " specgram_type=specgram_type,\t\n", - " use_dB_normalization=use_dB_normalization,\t\n", - " random_seed=random_seed,\t\n", - " keep_transcription_text=keep_transcription_text)\t\n", - "\n", - " if dist:\t\n", - " batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n", - " dataset,\t\n", - " batch_size,\t\n", - " num_replicas=None,\t\n", - " rank=None,\t\n", - " shuffle=is_training,\t\n", - " drop_last=is_training,\t\n", - " sortagrad=is_training,\t\n", - " shuffle_method=shuffle_method)\t\n", - " else:\t\n", - " batch_sampler = DeepSpeech2BatchSampler(\t\n", - " dataset,\t\n", - " shuffle=is_training,\t\n", - " batch_size=batch_size,\t\n", - " drop_last=is_training,\t\n", - " sortagrad=is_training,\t\n", - " shuffle_method=shuffle_method)\t\n", - "\n", - " def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n", - " \"\"\"\t\n", - " Padding audio features with zeros to make them have the same shape (or\t\n", - " a user-defined shape) within one bach.\t\n", - "\n", - " If ``padding_to`` is -1, the maximun shape in the batch will be used\t\n", - " as the target shape for padding. Otherwise, `padding_to` will be the\t\n", - " target shape (only refers to the second axis).\t\n", - "\n", - " If `flatten` is True, features will be flatten to 1darray.\t\n", - " \"\"\"\t\n", - " new_batch = []\t\n", - " # get target shape\t\n", - " max_length = max([audio.shape[1] for audio, text in batch])\t\n", - " if padding_to != -1:\t\n", - " if padding_to < max_length:\t\n", - " raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n", - " \"than any instance's shape in the batch\")\t\n", - " max_length = padding_to\t\n", - " max_text_length = max([len(text) for audio, text in batch])\t\n", - " # padding\t\n", - " padded_audios = []\t\n", - " audio_lens = []\t\n", - " texts, text_lens = [], []\t\n", - " for audio, text in batch:\t\n", - " padded_audio = np.zeros([audio.shape[0], max_length])\t\n", - " padded_audio[:, :audio.shape[1]] = audio\t\n", - " if flatten:\t\n", - " padded_audio = padded_audio.flatten()\t\n", - " padded_audios.append(padded_audio)\t\n", - " audio_lens.append(audio.shape[1])\t\n", - "\n", - " padded_text = np.zeros([max_text_length])\n", - " if is_training:\n", - " padded_text[:len(text)] = text\t# ids\n", - " else:\n", - " padded_text[:len(text)] = [ord(t) for t in text] # string\n", - " \n", - " texts.append(padded_text)\t\n", - " text_lens.append(len(text))\t\n", - "\n", - " padded_audios = np.array(padded_audios).astype('float32')\t\n", - " audio_lens = np.array(audio_lens).astype('int64')\t\n", - " texts = np.array(texts).astype('int32')\t\n", - " text_lens = np.array(text_lens).astype('int64')\t\n", - " return padded_audios, texts, audio_lens, text_lens\t\n", - "\n", - " loader = DataLoader(\t\n", - " dataset,\t\n", - " batch_sampler=batch_sampler,\t\n", - " collate_fn=partial(padding_batch, is_training=is_training),\t\n", - " num_workers=num_workers)\t\n", - " return loader" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "naval-brave", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/data/manifest.dev',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/aishell/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'linear',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "bearing-physics", - "metadata": {}, - "outputs": [], - "source": [ - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=True,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "classified-melissa", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[22823, 26102, 20195, 37324, 0 , 0 ],\n", - " [22238, 26469, 23601, 22909, 0 , 0 ],\n", - " [20108, 26376, 22235, 26085, 0 , 0 ],\n", - " [36824, 35201, 20445, 25345, 32654, 24863],\n", - " [29042, 27748, 21463, 23456, 0 , 0 ]])\n", - "test raw 大时代里\n", - "test raw 煲汤受宠\n", - "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [4, 4, 4, 6, 4])\n", - "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", - " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", - " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", - " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", - " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", - " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", - " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", - " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", - " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", - " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", - " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", - " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", - " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", - " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", - " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", - " ...,\n", - " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", - " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", - " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", - "\n", - " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", - " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", - " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", - " ...,\n", - " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", - " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", - " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" - ] - } - ], - "source": [ - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test', text)\n", - " print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('audio len', audio_len)\n", - " print('test len', text_len)\n", - " print('audio', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unexpected-skating", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "minus-modern", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/.notebook/dataloader_with_tokens_tokenids.ipynb b/.notebook/dataloader_with_tokens_tokenids.ipynb deleted file mode 100644 index 7d93dd009..000000000 --- a/.notebook/dataloader_with_tokens_tokenids.ipynb +++ /dev/null @@ -1,1204 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "medieval-monday", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "emerging-meter", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" - ] - } - ], - "source": [ - "import math\n", - "import random\n", - "import tarfile\n", - "import logging\n", - "import numpy as np\n", - "from collections import namedtuple\n", - "from functools import partial\n", - "\n", - "import paddle\n", - "from paddle.io import Dataset\n", - "from paddle.io import DataLoader\n", - "from paddle.io import BatchSampler\n", - "from paddle.io import DistributedBatchSampler\n", - "from paddle import distributed as dist\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "excessive-american", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "naval-brave", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:93] register user softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:109] register user relu to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('unit_type', str,\n", - " 'char',\n", - " \"Options: char, word, spm.\",\n", - " choices=['char', 'word', 'spm'])\n", - "add_arg('spm_model_prefix', str,\n", - " 'examples/tiny/s1/data/spm_bpe',\n", - " \"spm model prefix.\",)\n", - "add_arg('infer_manifest', str,\n", - " 'examples/tiny/s1/data/manifest.tiny',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/tiny/s1/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/tiny/s1/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/tiny/s1/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n", - "add_arg('delta_delta', bool, False, \"delta delta\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "wired-principal", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/aishell/s1/data/spm_bpe', 'infer_manifest': 'examples/aishell/s1/data/manifest.test', 'mean_std_path': '', 'vocab_path': 'examples/aishell/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('unit_type', str,\n", - " 'char',\n", - " \"Options: char, word, spm.\",\n", - " choices=['char', 'word', 'spm'])\n", - "add_arg('spm_model_prefix', str,\n", - " 'examples/aishell/s1/data/spm_bpe',\n", - " \"spm model prefix.\",)\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/s1/data/manifest.test',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " '',\n", - " \"examples/aishell/s1/data/mean_std.npz, Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/s1/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/s1/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n", - "add_arg('delta_delta', bool, False, \"delta delta\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "bearing-physics", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n", - "from deepspeech.frontend.speech import SpeechSegment\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "\n", - "\n", - "from deepspeech.io.collator import SpeechCollator\n", - "from deepspeech.io.dataset import ManifestDataset\n", - "from deepspeech.io.sampler import (\n", - " SortagradDistributedBatchSampler,\n", - " SortagradBatchSampler,\n", - ")\n", - "from deepspeech.io import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " unit_type=args.unit_type,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " spm_model_prefix=args.spm_model_prefix,\n", - " augmentation_config='{}',\n", - " max_input_len=27.0,\n", - " min_input_len=0.0,\n", - " max_output_len=float('inf'),\n", - " min_output_len=0.0,\n", - " max_output_input_ratio=float('inf'),\n", - " min_output_input_ratio=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=True,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " num_workers=0,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "classified-melissa", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[232 387 331 ... 249 249 262] int16\n", - "fbank\n", - "[-138 -219 -192 ... 338 324 351] int16\n", - "fbank\n", - "[ 694 1175 1022 ... 553 514 627] int16\n", - "fbank\n", - "[-39 -79 -53 ... 139 172 99] int16\n", - "fbank\n", - "[-277 -480 -425 ... 758 767 739] int16\n", - "fbank\n", - "[ 399 693 609 ... 1291 1270 1291] int16\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " if arr.dtype == np.object:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[ -750 -1254 -1107 ... 2276 1889 2067] int16\n", - "fbank\n", - "[ -127 -199 -149 ... -5243 -5065 -5398] int16\n", - "fbank\n", - "[ 465 783 677 ... 980 903 1008] int16\n", - "fbank\n", - "[ 90 160 157 ... -2 -16 -21] int16\n", - "fbank\n", - "[ 213 345 295 ... 2483 2246 2501] int16\n", - "fbank\n", - "[ -86 -159 -131 ... 270 258 290] int16\n", - "fbank\n", - "[-1023 -1714 -1505 ... 1532 1596 1575] int16\n", - "fbank\n", - "[-366 -602 -527 ... 374 370 379] int16\n", - "fbank\n", - "[ 761 1275 1127 ... 369 413 295] int16\n", - "fbank\n", - "[382 621 550 ... 161 161 174] int16\n", - "fbank\n", - "[ -28 -91 -120 ... 28 34 11] int16\n", - "fbank\n", - "[ -5 -5 -5 ... 268 294 341] int16\n", - "fbank\n", - "[240 417 684 ... 267 262 219] int16\n", - "fbank\n", - "[131 206 194 ... 383 320 343] int16\n", - "test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[31069, 21487, 29233, 30340, 20320, -1 , -1 ],\n", - " [20540, 24471, 19968, 25552, 30340, 26159, -1 ],\n", - " [36825, 20010, 31243, 24230, 26159, 32654, 30340],\n", - " [20108, 21040, 20108, -1 , -1 , -1 , -1 ],\n", - " [21435, 34892, 25919, 21270, -1 , -1 , -1 ]])\n", - "fbank\n", - "[1155 1890 1577 ... 1092 989 1130] int16\n", - "fbank\n", - "[296 358 296 ... 140 140 168] int16\n", - "fbank\n", - "[-50 -91 -63 ... 104 104 86] int16\n", - "fbank\n", - "[-37 -66 -50 ... -31 -45 -52] int16\n", - "fbank\n", - "[-401 -652 -547 ... -339 -307 -344] int16\n", - "fbank\n", - "[-21 -47 -51 ... 94 81 107] int16\n", - "fbank\n", - "[ 533 887 755 ... 3074 2853 3254] int16\n", - "fbank\n", - "[ 44 71 66 ... -628 -733 -601] int16\n", - "fbank\n", - "[ 50 86 79 ... 129 116 138] int16\n", - "fbank\n", - "[ 92 146 126 ... -208 -193 -179] int16\n", - "test raw: 祝可爱的你\n", - "test raw: 去行政化\n", - "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [184, 194, 196, 204, 207])\n", - "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [5, 6, 7, 3, 4])\n", - "audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[12.25633812, 12.61639309, 10.36936474, ..., 13.02949619, 11.51365757, 10.59789085],\n", - " [13.32148266, 13.41071606, 11.43800735, ..., 13.69783783, 12.83939362, 11.51259613],\n", - " [12.62640572, 12.53621101, 10.97212505, ..., 13.33757591, 12.32293034, 10.75493717],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[10.99619484, 11.35202599, 9.56922054 , ..., 9.94971657 , 9.88354111 , 9.55315971 ],\n", - " [10.44461155, 9.81688595 , 5.62538481 , ..., 10.60468388, 10.94417381, 9.42646980 ],\n", - " [10.23835754, 10.23407459, 7.99464273 , ..., 10.68097591, 9.91640091 , 10.04131031],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[14.10299397, 14.50298119, 12.87738323, ..., 12.62796497, 12.69949627, 11.43171215],\n", - " [13.85035992, 13.15289116, 10.66541386, ..., 13.34364223, 13.46972179, 11.02160740],\n", - " [13.19866467, 13.23537827, 11.65760899, ..., 12.72559357, 12.42716217, 11.74562359],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.85668373, 12.82431412, 11.68144703, ..., 14.10119247, 15.12791920, 13.68221378],\n", - " [13.19507027, 13.40244961, 11.43618393, ..., 13.32919979, 13.68267441, 12.73429012],\n", - " [13.02173328, 12.92082500, 11.44303989, ..., 12.77793121, 13.10915661, 11.77327728],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.90771198, 13.40234852, 13.01435471, ..., 13.80359459, 14.08088684, 13.17883396],\n", - " [14.06678009, 14.06943512, 12.52837276, ..., 13.66423225, 13.66300583, 13.60142994],\n", - " [12.58743191, 12.94520760, 11.75190544, ..., 14.28828907, 14.08229160, 13.02433395],\n", - " ...,\n", - " [16.20896912, 16.42283821, 14.94358730, ..., 12.91146755, 12.66766262, 11.76361752],\n", - " [13.49324894, 14.14653301, 13.16490936, ..., 13.23435783, 13.45378494, 12.60386276],\n", - " [15.56288910, 15.92445087, 14.90794277, ..., 13.43840790, 13.41075516, 12.55605984]]])\n" - ] - } - ], - "source": [ - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test:', text)\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('audio len:', audio_len)\n", - " print('test len:', text_len)\n", - " print('audio:', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unexpected-skating", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "minus-modern", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[232 387 331 ... 249 249 262] int16\n", - "fbank\n", - "[-138 -219 -192 ... 338 324 351] int16\n", - "fbank\n", - "[ 694 1175 1022 ... 553 514 627] int16\n", - "fbank\n", - "[-39 -79 -53 ... 139 172 99] int16\n", - "fbank\n", - "[-277 -480 -425 ... 758 767 739] int16\n", - "fbank\n", - "test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[2695, 505, 2332, 2553, 169, -1 , -1 ],\n", - " [ 230, 1237, 2 , 1556, 2553, 1694, -1 ],\n", - " [3703, 28 , 2739, 1172, 1694, 2966, 2553],\n", - " [ 70 , 355, 70 , -1 , -1 , -1 , -1 ],\n", - " [ 477, 3363, 1621, 412, -1 , -1 , -1 ]])\n", - "[ 399 693 609 ... 1291 1270 1291] int16\n", - "test raw: ઇǹज৹©\n", - "test raw: ǝണٕƜ\n", - "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [5, 6, 7, 3, 4])\n", - "audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[12.25794601, 12.61855793, 10.37306023, ..., 13.12571049, 11.53678799, 10.32210350],\n", - " [13.32333183, 13.41336918, 11.44248962, ..., 13.65861225, 12.79308128, 11.31168747],\n", - " [12.62584686, 12.53506088, 10.96861362, ..., 13.32526493, 12.41560936, 10.71458912],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[11.00003052, 11.35529137, 9.56384087 , ..., 10.06063652, 10.16322994, 9.43149185 ],\n", - " [10.44556236, 9.81155300 , 5.49400425 , ..., 10.84116268, 11.02734756, 9.42253590 ],\n", - " [10.23620510, 10.23321152, 7.99466419 , ..., 10.93381882, 10.28395081, 10.00841141],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[14.10379314, 14.50375748, 12.87825108, ..., 12.68065739, 12.62359715, 11.53773308],\n", - " [13.84964657, 13.15079498, 10.67198086, ..., 13.24875164, 13.45796680, 10.97363472],\n", - " [13.19808197, 13.23482990, 11.65900230, ..., 12.70375061, 12.41395664, 11.88668156],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.85676289, 12.82410812, 11.67961884, ..., 14.12018299, 15.14850044, 13.80065727],\n", - " [13.19532776, 13.40243340, 11.43492508, ..., 13.29144669, 13.70278549, 12.67841339],\n", - " [13.02196407, 12.92111111, 11.43998623, ..., 12.71165752, 13.16518497, 11.92028046],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.90661621, 13.40162563, 13.01394463, ..., 13.84056377, 14.11240959, 13.21227264],\n", - " [14.06642914, 14.06922340, 12.52955723, ..., 13.55829811, 13.60157204, 13.50268650],\n", - " [12.58881378, 12.94780254, 11.75758171, ..., 14.29055786, 14.12165928, 13.02695847],\n", - " ...,\n", - " [16.20891571, 16.42290306, 14.94398117, ..., 12.86083794, 12.63515949, 11.67581463],\n", - " [13.49345875, 14.14656067, 13.16498375, ..., 13.28024578, 13.40956783, 12.70357513],\n", - " [15.56265163, 15.92387581, 14.90643024, ..., 13.45694065, 13.44703197, 12.81099033]]])\n", - "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [184, 194, 196, 204, 207])\n" - ] - } - ], - "source": [ - "keep_transcription_text=False\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " unit_type=args.unit_type,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " spm_model_prefix=args.spm_model_prefix,\n", - " augmentation_config='{}',\n", - " max_input_len=27.0,\n", - " min_input_len=0.0,\n", - " max_output_len=float('inf'),\n", - " min_output_len=0.0,\n", - " max_output_input_ratio=float('inf'),\n", - " min_output_input_ratio=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=keep_transcription_text,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " num_workers=0,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)\n", - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test:', text)\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('test len:', text_len)\n", - " print('audio:', audio)\n", - " print('audio len:', audio_len)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "competitive-mounting", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "knowing-military", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 1, 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False, 'stride_ms': 10.0, 'window_ms': 25.0, 'sample_rate': 16000, 'manifest_path': 'examples/aishell/s1/data/manifest.train', 'output_path': 'examples/aishell/s1/data/mean_std.npz'}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "\n", - "add_arg('num_samples', int, 1, \"# of samples to for statistics.\")\n", - "add_arg('specgram_type', str, 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc, fbank.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n", - "add_arg('delta_delta', bool, False,\"Audio feature with delta delta.\")\n", - "add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n", - "add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n", - "add_arg('sample_rate', int, 16000, \"target sample rate.\")\n", - "add_arg('manifest_path', str,\n", - " 'examples/aishell/s1/data/manifest.train',\n", - " \"Filepath of manifest to compute normalizer's mean and stddev.\")\n", - "add_arg('output_path', str,\n", - " 'examples/aishell/s1/data/mean_std.npz',\n", - " \"Filepath of write mean and stddev to (.npz).\")\n", - "args = parser.parse_args([])\n", - "print(vars(args))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "unnecessary-province", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "\n", - "\n", - "def mean(args):\n", - " augmentation_pipeline = AugmentationPipeline('{}')\n", - " audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20,\n", - " dither=0.0)\n", - "\n", - " def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - " normalizer = FeatureNormalizer(\n", - " mean_std_filepath=None,\n", - " manifest_path=args.manifest_path,\n", - " featurize_func=augment_and_featurize,\n", - " num_samples=args.num_samples)\n", - " normalizer.write_to_file(args.output_path)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "interested-camping", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "[54. 90. 77. ... 58. 58. 61.]\n", - "29746\n", - "fbank\n", - "[54 90 77 ... 58 58 61] int16\n", - "(184, 80) float64\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n", - "(184, 80) float64\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n" - ] - } - ], - "source": [ - "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/test/S0916/BAC009S0916W0426.wav'\n", - "test='祝可爱的你'\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=False,\n", - " target_dB=-20,\n", - " dither=0.0)\n", - "samples = AudioSegment.from_file(wav)\n", - "print(samples._samples)\n", - "print(samples._samples * 2**15)\n", - "print(len(samples._samples))\n", - "feat = audio_featurizer.featurize(samples, False, False)\n", - "feat = feat.T\n", - "print(feat.shape, feat.dtype)\n", - "print(feat)\n", - "\n", - "from python_speech_features import logfbank\n", - "max_freq = args.sample_rate / 2\n", - "fbank_feat = logfbank(\n", - " signal=samples.to('int16'),\n", - " samplerate=args.sample_rate,\n", - " winlen=0.001 * args.window_ms,\n", - " winstep=0.001 * args.stride_ms,\n", - " nfilt=args.feat_dim,\n", - " nfft=512,\n", - " lowfreq=20,\n", - " highfreq=max_freq,\n", - " preemph=0.97,\n", - " dither=0.0,\n", - " wintype='povey')\n", - "print(fbank_feat.shape, fbank_feat.dtype)\n", - "print(fbank_feat)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "numeric-analyst", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(184, 160)\n", - "[ 8.59522397 8.43148278 8.36414052 8.45487173 8.31761643 8.04843683\n", - " 8.01683696 7.6574614 7.95521932 8.22945157 10.20138275 9.0447775\n", - " 9.14763398 9.18184349 9.03801065 9.04852307 8.67706728 8.71894271\n", - " 9.54553655 9.19535135 8.76413076 8.47828946 8.52586143 8.49469288\n", - " 8.72461247 8.28562879 8.11581393 7.99922156 7.91023364 8.04142296\n", - " 7.89762773 7.76257636 8.32043745 8.01592886 8.34109665 8.90115454\n", - " 8.48246945 7.98658664 8.05745122 8.11384088 8.18864479 8.8091827\n", - " 11.8067711 13.25258218 14.44311795 13.90515283 14.00120623 13.99801252\n", - " 13.81595394 13.6379904 13.3574897 13.14933334 12.96518543 13.02601156\n", - " 12.70246737 12.54410834 12.15615068 11.86574681 11.67497882 10.79645481\n", - " 10.48150035 10.03758575 10.05637027 9.92891308 10.06923218 12.43382431\n", - " 12.71428321 14.33135052 13.94470959 14.29188291 14.11483993 14.03496606\n", - " 13.78167331 13.66701466 14.40308625 14.73934137 15.09569382 14.89565815\n", - " 15.10519995 14.94383582 15.03275563 15.42194679 15.29219967 15.41602274\n", - " 15.39242545 15.76836177 16.259222 16.47777231 17.03366795 17.46165793\n", - " 17.52596217 17.78844031 17.99878075 18.11446843 17.95761578 17.99900337\n", - " 17.86282737 17.7290163 17.47686504 17.43425516 17.07750485 16.64395242\n", - " 15.68217043 14.90058399 14.45645737 14.0405463 14.89549542 16.00405781\n", - " 16.27301689 16.37572895 16.31219037 16.31765447 16.44819716 16.36281089\n", - " 16.24932823 15.79302555 14.76361963 13.95761882 13.48917053 13.45543501\n", - " 13.00091327 13.13854248 13.74596395 13.86340629 14.00656109 13.77432101\n", - " 13.64267001 13.35742634 13.23042234 12.97916104 12.80694468 12.70005006\n", - " 13.2802483 13.22644525 13.14579624 13.02536594 13.36511022 11.37167205\n", - " 12.11598045 12.47619798 12.83885973 11.63880287 11.42083924 11.08747705\n", - " 11.04093403 11.11263149 10.74353319 10.58734669 10.46180738 10.34157335\n", - " 9.63131146 9.70582692 9.29059204 8.94583657 8.66065094 8.46799095\n", - " 8.25064103 8.30239167 8.19463371 8.12104567 8.02731234 8.06412715\n", - " 7.84889951 7.73090283 7.74119562 7.85444657 7.80717312 7.7129933\n", - " 7.84087442 7.77907788 7.60660865 7.55051479 7.458385 7.496416\n", - " 7.69519793 7.49086759 7.32199493 8.01617458 7.58525375 7.06661122\n", - " 6.94653756 7.19874283 7.28515661 7.17574078]\n", - "(184,)\n", - "(184,)\n", - "[1.48370471 1.52174523 1.46984238 1.67010478 1.88757689 1.68825992\n", - " 1.74270259 1.55497318 1.29200818 1.68446481 1.88133219 1.97138928\n", - " 2.15910096 2.3149476 1.9820247 2.07694378 1.93498835 2.01493974\n", - " 2.39156824 2.02396518 1.69586449 1.63808752 1.64020228 1.43573473\n", - " 1.93092656 1.37466294 1.34704929 1.59600739 1.03960441 1.45276496\n", - " 1.59360131 1.57466343 1.89491479 1.79333746 1.32701974 1.49441767\n", - " 1.51466756 1.63497989 1.42858074 1.51135396 1.61077201 1.81066387\n", - " 1.83367783 2.3507094 2.87885378 3.26231227 2.1313117 1.98557548\n", - " 1.99105426 2.26150533 2.34298751 2.44621608 2.39201042 2.41226503\n", - " 2.5142992 3.03777565 2.81592295 2.75117863 2.78324175 2.68819666\n", - " 2.8945782 2.84464168 2.680973 2.78397395 2.47996808 1.71829563\n", - " 1.60636949 1.65992483 1.38122631 1.74831825 2.16006884 1.68076185\n", - " 1.69329487 1.44929837 1.63763312 1.80101076 2.01166253 2.03254244\n", - " 1.9583913 2.04542255 2.00859694 2.16600883 2.16095629 1.97541122\n", - " 2.13807632 2.06386436 2.2154187 2.84205688 2.54862449 2.64321545\n", - " 2.6805773 2.52300146 2.53209001 2.54682059 2.4521937 2.43155532\n", - " 2.42571275 2.23421289 2.23164529 2.23597192 2.14215121 2.10406703\n", - " 2.07962874 1.88506161 1.80092372 1.61156092 1.77426835 1.98765563\n", - " 2.0356793 1.87964187 1.779513 1.87187681 1.76463632 1.70978684\n", - " 1.76471778 1.75604749 1.62792552 1.73929352 1.6887024 1.8677704\n", - " 2.17342368 2.08166072 2.14567453 2.15936953 2.18351006 2.41010388\n", - " 2.26101752 2.25468001 2.23739715 2.15395133 2.04547813 1.92038843\n", - " 1.85491264 1.91905927 2.16709365 1.99924152 2.1850471 2.55461622\n", - " 2.72476673 1.69682926 1.73249614 2.06992695 2.1210591 1.66854454\n", - " 1.63907505 1.32203822 1.38992558 1.2436937 1.17932877 1.02963653\n", - " 1.26085036 1.16997132 1.09339504 1.14188689 1.18675772 1.31859788\n", - " 1.21746591 1.3872131 1.26095274 1.34885761 1.46633543 1.64506975\n", - " 1.36013821 1.45574721 1.43766588 1.65119054 1.57163772 1.55082968\n", - " 1.29413316 1.38351736 1.64234673 1.57186432 1.45381083 1.71204761\n", - " 1.51828607 1.30639985 1.32928395 1.49004237 1.6057589 1.81815735\n", - " 1.67784678 1.72180861 1.60703743 1.64850255]\n" - ] - } - ], - "source": [ - "a = np.hstack([feat, feat])\n", - "print(a.shape)\n", - "m = np.mean(a, axis=1)\n", - "print(m)\n", - "print(m.shape)\n", - "std = np.std(a, axis=1)\n", - "print(std.shape)\n", - "print(std)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "nonprofit-potato", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "hispanic-ethics", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torchaudio\n", - "import torchaudio.compliance.kaldi as kaldi\n", - "import torchaudio.sox_effects as sox_effects\n", - "from torch.nn.utils.rnn import pad_sequence\n", - "torchaudio.set_audio_backend(\"sox\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "changing-calvin", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 29746])\n", - "tensor([[54., 90., 77., ..., 58., 58., 61.]])\n", - "(184, 80)\n", - "[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n", - " [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n", - " [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n", - " ...\n", - " [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n", - " [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n", - " [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n", - "-----------\n", - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "(184, 80)\n", - "[[-10.177039 -10.717326 -15.46954 ... -10.546229 -11.897424 -12.987689]\n", - " [ -9.750411 -10.476343 -14.485752 ... -9.557108 -10.436023 -11.955799]\n", - " [-10.525113 -10.798049 -13.46475 ... -10.343097 -11.101464 -12.832712]\n", - " ...\n", - " [-10.649446 -10.907673 -14.056403 ... -10.578607 -11.790988 -12.038239]\n", - " [-10.816959 -11.114918 -12.88781 ... -10.570049 -11.199847 -13.101528]\n", - " [-14.320845 -13.03106 -13.036756 ... -10.829194 -11.171779 -12.634331]]\n", - "**************\n", - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "[54. 90. 77. ... 58. 58. 61.] float32\n", - "(184, 80)\n", - "[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n", - " [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n", - " [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n", - " ...\n", - " [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n", - " [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n", - " [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: torchaudio.backend.sox_backend.load_wav has been deprecated and will be removed from 0.9.0 release. Please use \"torchaudio.load\".\n", - " \"\"\"Entry point for launching an IPython kernel.\n" - ] - } - ], - "source": [ - "waveform, sample_rate = torchaudio.load_wav(wav)\n", - "print(waveform.shape)\n", - "print(waveform)\n", - "mat = kaldi.fbank(\n", - " waveform,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)\n", - "\n", - "print('-----------')\n", - "print(samples._samples)\n", - "aud = torch.tensor(samples._samples).view(1, -1)\n", - "mat = kaldi.fbank(\n", - " aud,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)\n", - "\n", - "print('**************')\n", - "print(samples._samples)\n", - "tmp = samples.to('int16').astype('float32')\n", - "print(tmp, tmp.dtype)\n", - "aud = torch.tensor(tmp).view(1, -1)\n", - "mat = kaldi.fbank(\n", - " aud,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "buried-dependence", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "silver-printing", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "outer-space", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(29746,)\n", - "[54 90 77 ... 58 58 61]\n", - "(184, 80)\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n", - "(184, 13)\n", - "[[ 14.73775998 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n", - " 8.86862748]\n", - " [ 15.31274834 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n", - " 6.78353115]\n", - " [ 13.82218765 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n", - " -0.05919222]\n", - " ...\n", - " [ 13.5837844 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n", - " 5.59045842]\n", - " [ 13.75757034 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n", - " 12.0785146 ]\n", - " [ 11.92813809 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n", - " 0.88337965]]\n" - ] - } - ], - "source": [ - "from python_speech_features import mfcc\n", - "from python_speech_features import delta\n", - "from python_speech_features import logfbank\n", - "import scipy.io.wavfile as iowav\n", - "\n", - "(rate,sig) = iowav.read(wav)\n", - "print(sig.shape)\n", - "print(sig)\n", - "\n", - "# note that generally nfilt=40 is used for speech recognition\n", - "fbank_feat = logfbank(sig,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n", - "print(fbank_feat.shape)\n", - "print(fbank_feat)\n", - "\n", - "# the computed fbank coefficents of english.wav with dimension [110,23]\n", - "# [ 12.2865\t12.6906\t13.1765\t15.714\t16.064\t15.7553\t16.5746\t16.9205\t16.6472\t16.1302\t16.4576\t16.7326\t16.8864\t17.7215\t18.88\t19.1377\t19.1495\t18.6683\t18.3886\t20.3506\t20.2772\t18.8248\t18.1899\n", - "# 11.9198\t13.146\t14.7215\t15.8642\t17.4288\t16.394\t16.8238\t16.1095\t16.4297\t16.6331\t16.3163\t16.5093\t17.4981\t18.3429\t19.6555\t19.6263\t19.8435\t19.0534\t19.001\t20.0287\t19.7707\t19.5852\t19.1112\n", - "# ...\n", - "# ...\n", - "# the same with that using kaldi commands: compute-fbank-feats --dither=0.0\n", - "\n", - "mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')\n", - "print(mfcc_feat.shape)\n", - "print(mfcc_feat)\n", - "\n", - "# the computed mfcc coefficents of english.wav with dimension [110,13]\n", - "# [ 17.1337\t-23.3651\t-7.41751\t-7.73686\t-21.3682\t-8.93884\t-3.70843\t4.68346\t-16.0676\t12.782\t-7.24054\t8.25089\t10.7292\n", - "# 17.1692\t-23.3028\t-5.61872\t-4.0075\t-23.287\t-20.6101\t-5.51584\t-6.15273\t-14.4333\t8.13052\t-0.0345329\t2.06274\t-0.564298\n", - "# ...\n", - "# ...\n", - "# the same with that using kaldi commands: compute-mfcc-feats --dither=0.0" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "sporting-school", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(184, 80)\n", - "[[-10.17703627 -10.71732606 -15.46954014 ... -10.54623152 -11.89742148\n", - " -12.98770428]\n", - " [ -9.75040771 -10.47634331 -14.48575413 ... -9.55710616 -10.43602673\n", - " -11.95581463]\n", - " [-10.52510987 -10.79804975 -13.46475161 ... -10.34309947 -11.10146239\n", - " -12.83273051]\n", - " ...\n", - " [-10.64944197 -10.90767335 -14.05640404 ... -10.57860915 -11.7909807\n", - " -12.03825021]\n", - " [-10.8169558 -11.11491806 -12.88781116 ... -10.57004889 -11.19985048\n", - " -13.10154358]\n", - " [-14.32084168 -13.03106051 -13.03675699 ... -10.82919465 -11.17177892\n", - " -12.63434434]]\n", - "(184, 13)\n", - "[[ -6.05665544 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n", - " 8.86862748]\n", - " [ -5.48166707 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n", - " 6.78353115]\n", - " [ -6.97222776 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n", - " -0.05919222]\n", - " ...\n", - " [ -7.21063102 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n", - " 5.59045842]\n", - " [ -7.03684508 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n", - " 12.0785146 ]\n", - " [ -8.86627732 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n", - " 0.88337965]]\n" - ] - } - ], - "source": [ - "fbank_feat = logfbank(samples._samples,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n", - "print(fbank_feat.shape)\n", - "print(fbank_feat)\n", - "\n", - "mfcc_feat = mfcc(samples._samples,dither=0,useEnergy=True,wintype='povey')\n", - "print(mfcc_feat.shape)\n", - "print(mfcc_feat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "restricted-license", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "specialized-threat", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/espnet_dataloader.ipynb b/.notebook/espnet_dataloader.ipynb deleted file mode 100644 index 1bfc13e3c..000000000 --- a/.notebook/espnet_dataloader.ipynb +++ /dev/null @@ -1,1541 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 147, - "id": "extensive-venice", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/\n" - ] - }, - { - "data": { - "text/plain": [ - "'/'" - ] - }, - "execution_count": 147, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 148, - "id": "correct-window", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "manifest.dev\t manifest.test-clean\t manifest.train\r\n", - "manifest.dev.raw manifest.test-clean.raw manifest.train.raw\r\n" - ] - } - ], - "source": [ - "!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/" - ] - }, - { - "cell_type": "code", - "execution_count": 149, - "id": "exceptional-cheese", - "metadata": {}, - "outputs": [], - "source": [ - "dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'" - ] - }, - { - "cell_type": "code", - "execution_count": 150, - "id": "extraordinary-orleans", - "metadata": {}, - "outputs": [], - "source": [ - "from deepspeech.frontend.utility import read_manifest" - ] - }, - { - "cell_type": "code", - "execution_count": 151, - "id": "returning-lighter", - "metadata": {}, - "outputs": [], - "source": [ - "dev_json = read_manifest(dev_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 152, - "id": "western-founder", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n", - " 'name': 'input1',\n", - " 'shape': [1063, 83]}],\n", - " 'output': [{'name': 'target1',\n", - " 'shape': [41, 5002],\n", - " 'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n", - " 'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n", - " 'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n", - " 'HITHER AND THITHER',\n", - " 'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n", - " 'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n", - " 'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n", - " '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n", - " '▁THITHER',\n", - " 'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n", - " '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n", - " '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n", - " '302 206 4504 4843 2394 627 4526'}],\n", - " 'utt': '116-288045-0000',\n", - " 'utt2spk': '116-288045'}\n", - "5542\n", - "\n" - ] - } - ], - "source": [ - "from pprint import pprint\n", - "pprint(dev_json[0])\n", - "print(len(dev_json))\n", - "print(type(dev_json))" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "motivated-receptor", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "import itertools\n", - "\n", - "import numpy as np\n", - "\n", - "from deepspeech.utils.log import Log\n", - "\n", - "__all__ = [\"make_batchset\"]\n", - "\n", - "logger = Log(__name__).getlog()\n", - "\n", - "\n", - "def batchfy_by_seq(\n", - " sorted_data,\n", - " batch_size,\n", - " max_length_in,\n", - " max_length_out,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " iaxis=0,\n", - " okey=\"output\",\n", - " oaxis=0, ):\n", - " \"\"\"Make batch set from json dictionary\n", - "\n", - " :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n", - " :param int batch_size: batch size\n", - " :param int max_length_in: maximum length of input to decide adaptive batch size\n", - " :param int max_length_out: maximum length of output to decide adaptive batch size\n", - " :param int min_batch_size: mininum batch size (for multi-gpu)\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - " :param str ikey: key to access input\n", - " (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n", - " :param int iaxis: dimension to access input\n", - " (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n", - " :param str okey: key to access output\n", - " (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n", - " :param int oaxis: dimension to access output\n", - " (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n", - " :return: List[List[Tuple[str, dict]]] list of batches\n", - " \"\"\"\n", - " if batch_size <= 0:\n", - " raise ValueError(f\"Invalid batch_size={batch_size}\")\n", - "\n", - " # check #utts is more than min_batch_size\n", - " if len(sorted_data) < min_batch_size:\n", - " raise ValueError(\n", - " f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n", - " )\n", - "\n", - " # make list of minibatches\n", - " minibatches = []\n", - " start = 0\n", - " while True:\n", - " _, info = sorted_data[start]\n", - " ilen = int(info[ikey][iaxis][\"shape\"][0])\n", - " olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n", - " max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n", - " factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n", - " # change batchsize depending on the input and output length\n", - " # if ilen = 1000 and max_length_in = 800\n", - " # then b = batchsize / 2\n", - " # and max(min_batches, .) avoids batchsize = 0\n", - " bs = max(min_batch_size, int(batch_size / (1 + factor)))\n", - " end = min(len(sorted_data), start + bs)\n", - " minibatch = sorted_data[start:end]\n", - " if shortest_first:\n", - " minibatch.reverse()\n", - "\n", - " # check each batch is more than minimum batchsize\n", - " if len(minibatch) < min_batch_size:\n", - " mod = min_batch_size - len(minibatch) % min_batch_size\n", - " additional_minibatch = [\n", - " sorted_data[i] for i in np.random.randint(0, start, mod)\n", - " ]\n", - " if shortest_first:\n", - " additional_minibatch.reverse()\n", - " minibatch.extend(additional_minibatch)\n", - " minibatches.append(minibatch)\n", - "\n", - " if end == len(sorted_data):\n", - " break\n", - " start = end\n", - "\n", - " # batch: List[List[Tuple[str, dict]]]\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_by_bin(\n", - " sorted_data,\n", - " batch_bins,\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " okey=\"output\", ):\n", - " \"\"\"Make variably sized batch set, which maximizes\n", - "\n", - " the number of bins up to `batch_bins`.\n", - "\n", - " :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n", - " :param int batch_bins: Maximum frames of a batch\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param int test: Return only every `test` batches\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - "\n", - " :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n", - " :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n", - "\n", - " :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n", - " \"\"\"\n", - " if batch_bins <= 0:\n", - " raise ValueError(f\"invalid batch_bins={batch_bins}\")\n", - " length = len(sorted_data)\n", - " idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n", - " odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " minibatches = []\n", - " start = 0\n", - " n = 0\n", - " while True:\n", - " # Dynamic batch size depending on size of samples\n", - " b = 0\n", - " next_size = 0\n", - " max_olen = 0\n", - " while next_size < batch_bins and (start + b) < length:\n", - " ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n", - " olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n", - " if olen > max_olen:\n", - " max_olen = olen\n", - " next_size = (max_olen + ilen) * (b + 1)\n", - " if next_size <= batch_bins:\n", - " b += 1\n", - " elif next_size == 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n", - " f\"Please increase the value\")\n", - " end = min(length, start + max(min_batch_size, b))\n", - " batch = sorted_data[start:end]\n", - " if shortest_first:\n", - " batch.reverse()\n", - " minibatches.append(batch)\n", - " # Check for min_batch_size and fixes the batches if needed\n", - " i = -1\n", - " while len(minibatches[i]) < min_batch_size:\n", - " missing = min_batch_size - len(minibatches[i])\n", - " if -i == len(minibatches):\n", - " minibatches[i + 1].extend(minibatches[i])\n", - " minibatches = minibatches[1:]\n", - " break\n", - " else:\n", - " minibatches[i].extend(minibatches[i - 1][:missing])\n", - " minibatches[i - 1] = minibatches[i - 1][missing:]\n", - " i -= 1\n", - " if end == length:\n", - " break\n", - " start = end\n", - " n += 1\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " lengths = [len(x) for x in minibatches]\n", - " logger.info(\n", - " str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n", - " + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n", - " int(np.mean(lengths))) + \" samples).\")\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_by_frame(\n", - " sorted_data,\n", - " max_frames_in,\n", - " max_frames_out,\n", - " max_frames_inout,\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " okey=\"output\", ):\n", - " \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n", - "\n", - " :param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json\n", - " :param int max_frames_in: Maximum input frames of a batch\n", - " :param int max_frames_out: Maximum output frames of a batch\n", - " :param int max_frames_inout: Maximum input+output frames of a batch\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param int test: Return only every `test` batches\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - "\n", - " :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n", - " :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n", - "\n", - " :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n", - " \"\"\"\n", - " if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n", - " raise ValueError(\n", - " \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n", - " \"`--batch-frames-inout` should be > 0\")\n", - " length = len(sorted_data)\n", - " minibatches = []\n", - " start = 0\n", - " end = 0\n", - " while end != length:\n", - " # Dynamic batch size depending on size of samples\n", - " b = 0\n", - " max_olen = 0\n", - " max_ilen = 0\n", - " while (start + b) < length:\n", - " ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n", - " if ilen > max_frames_in and max_frames_in != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n", - " f\"Please increase the value\")\n", - " olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n", - " if olen > max_frames_out and max_frames_out != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n", - " f\"Please increase the value\")\n", - " if ilen + olen > max_frames_inout and max_frames_inout != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n", - " f\"Please increase the value\")\n", - " max_olen = max(max_olen, olen)\n", - " max_ilen = max(max_ilen, ilen)\n", - " in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n", - " out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n", - " inout_ok = (max_ilen + max_olen) * (\n", - " b + 1) <= max_frames_inout or max_frames_inout == 0\n", - " if in_ok and out_ok and inout_ok:\n", - " # add more seq in the minibatch\n", - " b += 1\n", - " else:\n", - " # no more seq in the minibatch\n", - " break\n", - " end = min(length, start + b)\n", - " batch = sorted_data[start:end]\n", - " if shortest_first:\n", - " batch.reverse()\n", - " minibatches.append(batch)\n", - " # Check for min_batch_size and fixes the batches if needed\n", - " i = -1\n", - " while len(minibatches[i]) < min_batch_size:\n", - " missing = min_batch_size - len(minibatches[i])\n", - " if -i == len(minibatches):\n", - " minibatches[i + 1].extend(minibatches[i])\n", - " minibatches = minibatches[1:]\n", - " break\n", - " else:\n", - " minibatches[i].extend(minibatches[i - 1][:missing])\n", - " minibatches[i - 1] = minibatches[i - 1][missing:]\n", - " i -= 1\n", - " start = end\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " lengths = [len(x) for x in minibatches]\n", - " logger.info(\n", - " str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n", - " + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n", - " int(np.mean(lengths))) + \" samples).\")\n", - "\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n", - " shortest_first):\n", - " import random\n", - "\n", - " logger.info(\"use shuffled batch.\")\n", - " sorted_data = random.sample(data.items(), len(data.items()))\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " # make list of minibatches\n", - " minibatches = []\n", - " start = 0\n", - " while True:\n", - " end = min(len(sorted_data), start + batch_size)\n", - " # check each batch is more than minimum batchsize\n", - " minibatch = sorted_data[start:end]\n", - " if shortest_first:\n", - " minibatch.reverse()\n", - " if len(minibatch) < min_batch_size:\n", - " mod = min_batch_size - len(minibatch) % min_batch_size\n", - " additional_minibatch = [\n", - " sorted_data[i] for i in np.random.randint(0, start, mod)\n", - " ]\n", - " if shortest_first:\n", - " additional_minibatch.reverse()\n", - " minibatch.extend(additional_minibatch)\n", - " minibatches.append(minibatch)\n", - " if end == len(sorted_data):\n", - " break\n", - " start = end\n", - "\n", - " # for debugging\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " logger.info(\"# minibatches: \" + str(len(minibatches)))\n", - " return minibatches\n", - "\n", - "\n", - "BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n", - "BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n", - "\n", - "\n", - "def make_batchset(\n", - " data,\n", - " batch_size=0,\n", - " max_length_in=float(\"inf\"),\n", - " max_length_out=float(\"inf\"),\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " batch_sort_key=\"input\",\n", - " count=\"auto\",\n", - " batch_bins=0,\n", - " batch_frames_in=0,\n", - " batch_frames_out=0,\n", - " batch_frames_inout=0,\n", - " iaxis=0,\n", - " oaxis=0, ):\n", - " \"\"\"Make batch set from json dictionary\n", - "\n", - " if utts have \"category\" value,\n", - "\n", - " >>> data = {'utt1': {'category': 'A', 'input': ...},\n", - " ... 'utt2': {'category': 'B', 'input': ...},\n", - " ... 'utt3': {'category': 'B', 'input': ...},\n", - " ... 'utt4': {'category': 'A', 'input': ...}}\n", - " >>> make_batchset(data, batchsize=2, ...)\n", - " [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]\n", - "\n", - " Note that if any utts doesn't have \"category\",\n", - " perform as same as batchfy_by_{count}\n", - "\n", - " :param List[Dict[str, Any]] data: dictionary loaded from data.json\n", - " :param int batch_size: maximum number of sequences in a minibatch.\n", - " :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n", - " :param int batch_frames_in: maximum number of input frames in a minibatch.\n", - " :param int batch_frames_out: maximum number of output frames in a minibatch.\n", - " :param int batch_frames_out: maximum number of input+output frames in a minibatch.\n", - " :param str count: strategy to count maximum size of batch.\n", - " For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n", - "\n", - " :param int max_length_in: maximum length of input to decide adaptive batch size\n", - " :param int max_length_out: maximum length of output to decide adaptive batch size\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - " :param str batch_sort_key: how to sort data before creating minibatches\n", - " [\"input\", \"output\", \"shuffle\"]\n", - " :param bool swap_io: if True, use \"input\" as output and \"output\"\n", - " as input in `data` dict\n", - " :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n", - " as input in `data` dict\n", - " :param int iaxis: dimension to access input\n", - " (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n", - " :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n", - " reserved for future research, -1 means all axis.)\n", - " :return: List[List[Tuple[str, dict]]] list of batches\n", - " \"\"\"\n", - "\n", - " # check args\n", - " if count not in BATCH_COUNT_CHOICES:\n", - " raise ValueError(\n", - " f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n", - " if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n", - " raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n", - " f\"one of {BATCH_SORT_KEY_CHOICES}\")\n", - "\n", - " ikey = \"input\"\n", - " okey = \"output\"\n", - " batch_sort_axis = 0 # index of list \n", - "\n", - " if count == \"auto\":\n", - " if batch_size != 0:\n", - " count = \"seq\"\n", - " elif batch_bins != 0:\n", - " count = \"bin\"\n", - " elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n", - " count = \"frame\"\n", - " else:\n", - " raise ValueError(\n", - " f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n", - " )\n", - " logger.info(f\"count is auto detected as {count}\")\n", - "\n", - " if count != \"seq\" and batch_sort_key == \"shuffle\":\n", - " raise ValueError(\n", - " \"batch_sort_key=shuffle is only available if batch_count=seq\")\n", - "\n", - " category2data = {} # Dict[str, dict]\n", - " for v in data:\n", - " k = v['utt']\n", - " category2data.setdefault(v.get(\"category\"), {})[k] = v\n", - "\n", - " batches_list = [] # List[List[List[Tuple[str, dict]]]]\n", - " for d in category2data.values():\n", - " if batch_sort_key == \"shuffle\":\n", - " batches = batchfy_shuffle(d, batch_size, min_batch_size,\n", - " num_batches, shortest_first)\n", - " batches_list.append(batches)\n", - " continue\n", - "\n", - " # sort it by input lengths (long to short)\n", - " sorted_data = sorted(\n", - " d.items(),\n", - " key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n", - " reverse=not shortest_first, )\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " \n", - " if count == \"seq\":\n", - " batches = batchfy_by_seq(\n", - " sorted_data,\n", - " batch_size=batch_size,\n", - " max_length_in=max_length_in,\n", - " max_length_out=max_length_out,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " iaxis=iaxis,\n", - " okey=okey,\n", - " oaxis=oaxis, )\n", - " if count == \"bin\":\n", - " batches = batchfy_by_bin(\n", - " sorted_data,\n", - " batch_bins=batch_bins,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " okey=okey, )\n", - " if count == \"frame\":\n", - " batches = batchfy_by_frame(\n", - " sorted_data,\n", - " max_frames_in=batch_frames_in,\n", - " max_frames_out=batch_frames_out,\n", - " max_frames_inout=batch_frames_inout,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " okey=okey, )\n", - " batches_list.append(batches)\n", - "\n", - " if len(batches_list) == 1:\n", - " batches = batches_list[0]\n", - " else:\n", - " # Concat list. This way is faster than \"sum(batch_list, [])\"\n", - " batches = list(itertools.chain(*batches_list))\n", - "\n", - " # for debugging\n", - " if num_batches > 0:\n", - " batches = batches[:num_batches]\n", - " logger.info(\"# minibatches: \" + str(len(batches)))\n", - "\n", - " # batch: List[List[Tuple[str, dict]]]\n", - " return batches\n" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "acquired-hurricane", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n", - "[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n", - "[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "555\n" - ] - } - ], - "source": [ - "batch_size=10\n", - "maxlen_in=300\n", - "maxlen_out=400\n", - "minibatches=0 # for debug\n", - "min_batch_size=2\n", - "use_sortagrad=True\n", - "batch_count='seq'\n", - "batch_bins=0\n", - "batch_frames_in=3000\n", - "batch_frames_out=0\n", - "batch_frames_inout=0\n", - " \n", - "dev_data = make_batchset(\n", - " dev_json,\n", - " batch_size,\n", - " maxlen_in,\n", - " maxlen_out,\n", - " minibatches, # for debug\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=use_sortagrad,\n", - " batch_sort_key=\"shuffle\",\n", - " count=batch_count,\n", - " batch_bins=batch_bins,\n", - " batch_frames_in=batch_frames_in,\n", - " batch_frames_out=batch_frames_out,\n", - " batch_frames_inout=batch_frames_inout,\n", - " iaxis=0,\n", - " oaxis=0, )\n", - "print(len(dev_data))\n", - "# for i in range(len(dev_data)):\n", - "# print(len(dev_data[i]))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "warming-malpractice", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n", - "Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n", - "\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n", - "You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" - ] - } - ], - "source": [ - "!pip install kaldiio" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "equipped-subject", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "superb-methodology", - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "import kaldiio\n", - "\n", - "class LoadInputsAndTargets():\n", - " \"\"\"Create a mini-batch from a list of dicts\n", - "\n", - " >>> batch = [('utt1',\n", - " ... dict(input=[dict(feat='some.ark:123',\n", - " ... filetype='mat',\n", - " ... name='input1',\n", - " ... shape=[100, 80])],\n", - " ... output=[dict(tokenid='1 2 3 4',\n", - " ... name='target1',\n", - " ... shape=[4, 31])]]))\n", - " >>> l = LoadInputsAndTargets()\n", - " >>> feat, target = l(batch)\n", - "\n", - " :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n", - " :param: str preprocess_conf: The path of a json file for pre-processing\n", - " :param: bool load_input: If False, not to load the input data\n", - " :param: bool load_output: If False, not to load the output data\n", - " :param: bool sort_in_input_length: Sort the mini-batch in descending order\n", - " of the input length\n", - " :param: bool use_speaker_embedding: Used for tts mode only\n", - " :param: bool use_second_target: Used for tts mode only\n", - " :param: dict preprocess_args: Set some optional arguments for preprocessing\n", - " :param: Optional[dict] preprocess_args: Used for tts mode only\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " mode=\"asr\",\n", - " preprocess_conf=None,\n", - " load_input=True,\n", - " load_output=True,\n", - " sort_in_input_length=True,\n", - " preprocess_args=None,\n", - " keep_all_data_on_mem=False, ):\n", - " self._loaders = {}\n", - "\n", - " if mode not in [\"asr\"]:\n", - " raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n", - "\n", - " if preprocess_conf is not None:\n", - " self.preprocessing = AugmentationPipeline(preprocess_conf)\n", - " logging.warning(\n", - " \"[Experimental feature] Some preprocessing will be done \"\n", - " \"for the mini-batch creation using {}\".format(\n", - " self.preprocessing))\n", - " else:\n", - " # If conf doesn't exist, this function don't touch anything.\n", - " self.preprocessing = None\n", - "\n", - " self.mode = mode\n", - " self.load_output = load_output\n", - " self.load_input = load_input\n", - " self.sort_in_input_length = sort_in_input_length\n", - " if preprocess_args is None:\n", - " self.preprocess_args = {}\n", - " else:\n", - " assert isinstance(preprocess_args, dict), type(preprocess_args)\n", - " self.preprocess_args = dict(preprocess_args)\n", - "\n", - " self.keep_all_data_on_mem = keep_all_data_on_mem\n", - "\n", - " def __call__(self, batch, return_uttid=False):\n", - " \"\"\"Function to load inputs and targets from list of dicts\n", - "\n", - " :param List[Tuple[str, dict]] batch: list of dict which is subset of\n", - " loaded data.json\n", - " :param bool return_uttid: return utterance ID information for visualization\n", - " :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n", - " :return: list of input feature sequences\n", - " [(T_1, D), (T_2, D), ..., (T_B, D)]\n", - " :rtype: list of float ndarray\n", - " :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n", - " :rtype: list of int ndarray\n", - "\n", - " \"\"\"\n", - " x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n", - " y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n", - " uttid_list = [] # List[str]\n", - "\n", - " for uttid, info in batch:\n", - " uttid_list.append(uttid)\n", - "\n", - " if self.load_input:\n", - " # Note(kamo): This for-loop is for multiple inputs\n", - " for idx, inp in enumerate(info[\"input\"]):\n", - " # {\"input\":\n", - " # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # \"name\": \"input1\", ...}], ...}\n", - " x = self._get_from_loader(\n", - " filepath=inp[\"feat\"],\n", - " filetype=inp.get(\"filetype\", \"mat\"))\n", - " x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n", - "\n", - " if self.load_output:\n", - " for idx, inp in enumerate(info[\"output\"]):\n", - " if \"tokenid\" in inp:\n", - " # ======= Legacy format for output =======\n", - " # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n", - " x = np.fromiter(\n", - " map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n", - " else:\n", - " # ======= New format =======\n", - " # {\"input\":\n", - " # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # \"name\": \"target1\", ...}], ...}\n", - " x = self._get_from_loader(\n", - " filepath=inp[\"feat\"],\n", - " filetype=inp.get(\"filetype\", \"mat\"))\n", - "\n", - " y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n", - "\n", - " if self.mode == \"asr\":\n", - " return_batch, uttid_list = self._create_batch_asr(\n", - " x_feats_dict, y_feats_dict, uttid_list)\n", - " else:\n", - " raise NotImplementedError(self.mode)\n", - "\n", - " if self.preprocessing is not None:\n", - " # Apply pre-processing all input features\n", - " for x_name in return_batch.keys():\n", - " if x_name.startswith(\"input\"):\n", - " return_batch[x_name] = self.preprocessing(\n", - " return_batch[x_name], uttid_list,\n", - " **self.preprocess_args)\n", - "\n", - " if return_uttid:\n", - " return tuple(return_batch.values()), uttid_list\n", - "\n", - " # Doesn't return the names now.\n", - " return tuple(return_batch.values())\n", - "\n", - " def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n", - " \"\"\"Create a OrderedDict for the mini-batch\n", - "\n", - " :param OrderedDict x_feats_dict:\n", - " e.g. {\"input1\": [ndarray, ndarray, ...],\n", - " \"input2\": [ndarray, ndarray, ...]}\n", - " :param OrderedDict y_feats_dict:\n", - " e.g. {\"target1\": [ndarray, ndarray, ...],\n", - " \"target2\": [ndarray, ndarray, ...]}\n", - " :param: List[str] uttid_list:\n", - " Give uttid_list to sort in the same order as the mini-batch\n", - " :return: batch, uttid_list\n", - " :rtype: Tuple[OrderedDict, List[str]]\n", - " \"\"\"\n", - " # handle single-input and multi-input (paralell) asr mode\n", - " xs = list(x_feats_dict.values())\n", - "\n", - " if self.load_output:\n", - " ys = list(y_feats_dict.values())\n", - " assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n", - "\n", - " # get index of non-zero length samples\n", - " nonzero_idx = list(\n", - " filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n", - " for n in range(1, len(y_feats_dict)):\n", - " nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n", - " else:\n", - " # Note(kamo): Be careful not to make nonzero_idx to a generator\n", - " nonzero_idx = list(range(len(xs[0])))\n", - "\n", - " if self.sort_in_input_length:\n", - " # sort in input lengths based on the first input\n", - " nonzero_sorted_idx = sorted(\n", - " nonzero_idx, key=lambda i: -len(xs[0][i]))\n", - " else:\n", - " nonzero_sorted_idx = nonzero_idx\n", - "\n", - " if len(nonzero_sorted_idx) != len(xs[0]):\n", - " logging.warning(\n", - " \"Target sequences include empty tokenid (batch {} -> {}).\".\n", - " format(len(xs[0]), len(nonzero_sorted_idx)))\n", - "\n", - " # remove zero-length samples\n", - " xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n", - " uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n", - "\n", - " x_names = list(x_feats_dict.keys())\n", - " if self.load_output:\n", - " ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n", - " y_names = list(y_feats_dict.keys())\n", - "\n", - " # Keeping x_name and y_name, e.g. input1, for future extension\n", - " return_batch = OrderedDict([\n", - " * [(x_name, x) for x_name, x in zip(x_names, xs)],\n", - " * [(y_name, y) for y_name, y in zip(y_names, ys)],\n", - " ])\n", - " else:\n", - " return_batch = OrderedDict(\n", - " [(x_name, x) for x_name, x in zip(x_names, xs)])\n", - " return return_batch, uttid_list\n", - "\n", - " def _get_from_loader(self, filepath, filetype):\n", - " \"\"\"Return ndarray\n", - "\n", - " In order to make the fds to be opened only at the first referring,\n", - " the loader are stored in self._loaders\n", - "\n", - " >>> ndarray = loader.get_from_loader(\n", - " ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n", - "\n", - " :param: str filepath:\n", - " :param: str filetype:\n", - " :return:\n", - " :rtype: np.ndarray\n", - " \"\"\"\n", - " if filetype == \"hdf5\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = h5py.File(filepath, \"r\")\n", - " self._loaders[filepath] = loader\n", - " return loader[key][()]\n", - " elif filetype == \"sound.hdf5\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"sound.hdf5\",\n", - " # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n", - " self._loaders[filepath] = loader\n", - " array, rate = loader[key]\n", - " return array\n", - " elif filetype == \"sound\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.wav\",\n", - " # \"filetype\": \"sound\"},\n", - " # Assume PCM16\n", - " if not self.keep_all_data_on_mem:\n", - " array, _ = soundfile.read(filepath, dtype=\"int16\")\n", - " return array\n", - " if filepath not in self._loaders:\n", - " array, _ = soundfile.read(filepath, dtype=\"int16\")\n", - " self._loaders[filepath] = array\n", - " return self._loaders[filepath]\n", - " elif filetype == \"npz\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"npz\",\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = np.load(filepath)\n", - " self._loaders[filepath] = loader\n", - " return loader[key]\n", - " elif filetype == \"npy\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.npy\",\n", - " # \"filetype\": \"npy\"},\n", - " if not self.keep_all_data_on_mem:\n", - " return np.load(filepath)\n", - " if filepath not in self._loaders:\n", - " self._loaders[filepath] = np.load(filepath)\n", - " return self._loaders[filepath]\n", - " elif filetype in [\"mat\", \"vec\"]:\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.ark:123\",\n", - " # \"filetype\": \"mat\"}]},\n", - " # In this case, \"123\" indicates the starting points of the matrix\n", - " # load_mat can load both matrix and vector\n", - " if not self.keep_all_data_on_mem:\n", - " return kaldiio.load_mat(filepath)\n", - " if filepath not in self._loaders:\n", - " self._loaders[filepath] = kaldiio.load_mat(filepath)\n", - " return self._loaders[filepath]\n", - " elif filetype == \"scp\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"scp\",\n", - " filepath, key = filepath.split(\":\", 1)\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = kaldiio.load_scp(filepath)\n", - " self._loaders[filepath] = loader\n", - " return loader[key]\n", - " else:\n", - " raise NotImplementedError(\n", - " \"Not supported: loader_type={}\".format(filetype))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "id": "monthly-muscle", - "metadata": {}, - "outputs": [], - "source": [ - "preprocess_conf=None\n", - "train_mode=True\n", - "load = LoadInputsAndTargets(\n", - " mode=\"asr\",\n", - " load_output=True,\n", - " preprocess_conf=preprocess_conf,\n", - " preprocess_args={\"train\":\n", - " train_mode}, # Switch the mode of preprocessing\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "id": "periodic-senegal", - "metadata": {}, - "outputs": [], - "source": [ - "res = load(dev_data[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "id": "502d3f4d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "2\n", - "10\n", - "10\n", - "(1174, 83) float32\n", - "(29,) int64\n" - ] - } - ], - "source": [ - "print(type(res))\n", - "print(len(res))\n", - "print(len(res[0]))\n", - "print(len(res[1]))\n", - "print(res[0][0].shape, res[0][0].dtype)\n", - "print(res[1][0].shape, res[1][0].dtype)\n", - "# Tuple[Tuple[np.ndarry], Tuple[np.ndarry]]\n", - "# 2[10, 10]\n", - "# feats, labels" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "id": "humanitarian-container", - "metadata": {}, - "outputs": [], - "source": [ - "(inputs, outputs), utts = load(dev_data[0], return_uttid=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "id": "heard-prize", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n", - "10\n" - ] - } - ], - "source": [ - "print(utts, len(utts))\n", - "print(len(inputs))" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "id": "convinced-animation", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from deepspeech.io.utility import pad_list\n", - "class CustomConverter():\n", - " \"\"\"Custom batch converter.\n", - "\n", - " Args:\n", - " subsampling_factor (int): The subsampling factor.\n", - " dtype (paddle.dtype): Data type to convert.\n", - "\n", - " \"\"\"\n", - "\n", - " def __init__(self, subsampling_factor=1, dtype=np.float32):\n", - " \"\"\"Construct a CustomConverter object.\"\"\"\n", - " self.subsampling_factor = subsampling_factor\n", - " self.ignore_id = -1\n", - " self.dtype = dtype\n", - "\n", - " def __call__(self, batch):\n", - " \"\"\"Transform a batch and send it to a device.\n", - "\n", - " Args:\n", - " batch (list): The batch to transform.\n", - "\n", - " Returns:\n", - " tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)\n", - "\n", - " \"\"\"\n", - " # batch should be located in list\n", - " assert len(batch) == 1\n", - " (xs, ys), utts = batch[0]\n", - "\n", - " # perform subsampling\n", - " if self.subsampling_factor > 1:\n", - " xs = [x[::self.subsampling_factor, :] for x in xs]\n", - "\n", - " # get batch of lengths of input sequences\n", - " ilens = np.array([x.shape[0] for x in xs])\n", - "\n", - " # perform padding and convert to tensor\n", - " # currently only support real number\n", - " if xs[0].dtype.kind == \"c\":\n", - " xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n", - " xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n", - " # Note(kamo):\n", - " # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n", - " # Don't create ComplexTensor and give it E2E here\n", - " # because torch.nn.DataParellel can't handle it.\n", - " xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n", - " else:\n", - " xs_pad = pad_list(xs, 0).astype(self.dtype)\n", - "\n", - " # NOTE: this is for multi-output (e.g., speech translation)\n", - " ys_pad = pad_list(\n", - " [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n", - " self.ignore_id)\n", - "\n", - " olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n", - " return utts, xs_pad, ilens, ys_pad, olens" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "id": "0b92ade5", - "metadata": {}, - "outputs": [], - "source": [ - "convert = CustomConverter()" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "id": "8dbd847c", - "metadata": {}, - "outputs": [], - "source": [ - "utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "id": "31c085f4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n", - "(10, 1174, 83)\n", - "(10,)\n", - "[1174 821 716 628 597 473 463 441 419 358]\n", - "(10, 32)\n", - "[[4502 2404 4223 3204 4502 587 1018 3861 2932 713 2458 2916 253 4508\n", - " 627 1395 713 4504 957 2761 209 2967 3173 3918 2598 4100 3 2816\n", - " 4990 -1 -1 -1]\n", - " [1005 451 210 278 3411 206 482 2307 573 4502 3848 4577 4273 2388\n", - " 4444 89 4919 278 1264 4501 2371 3 139 113 2603 4962 3158 3325\n", - " 4577 814 4587 1422]\n", - " [2345 4144 2291 200 713 2345 532 999 2458 3076 545 2458 4832 3038\n", - " 4499 482 2812 1260 3080 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2345 832 4577 4920 4501 2345 2298 1236 381 288 389 101 2495 4172\n", - " 4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n", - " 4723 2371 -1 -1]\n", - " [4502 4207 432 3204 4502 2396 125 935 433 2598 483 18 327 2\n", - " 389 627 4512 2340 713 482 1981 4525 4031 269 2030 1340 101 2495\n", - " 4013 4844 -1 -1]\n", - " [4502 4892 3204 1892 3780 389 482 2774 3013 89 192 2495 4502 3475\n", - " 389 66 370 343 404 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2458 2314 4577 2340 2863 1254 303 269 2 389 932 2079 4577 299\n", - " 195 3233 4508 2 89 814 3144 1091 3204 3250 2193 3414 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2391 1785 443 78 39 4962 2340 829 599 4593 278 4681 202 407\n", - " 269 194 182 4577 482 4308 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [ 627 4873 2175 363 202 404 1018 4577 4502 3412 4875 2286 107 122\n", - " 4832 2345 3896 89 2368 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [ 481 174 474 599 1881 3252 2842 742 4502 2545 107 88 3204 4525\n", - " 4517 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]]\n", - "[29 32 19 30 30 19 26 20 19 15]\n", - "float32\n", - "int64\n", - "int64\n", - "int64\n" - ] - } - ], - "source": [ - "print(utts)\n", - "print(xs.shape)\n", - "print(ilen.shape)\n", - "print(ilen)\n", - "print(ys.shape)\n", - "print(ys)\n", - "print(olen)\n", - "print(xs.dtype)\n", - "print(ilen.dtype)\n", - "print(ys.dtype)\n", - "print(olen.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "id": "72e9ba60", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 230, - "id": "64593e5f", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "from paddle.io import DataLoader\n", - "\n", - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.io.batchfy import make_batchset\n", - "from deepspeech.io.converter import CustomConverter\n", - "from deepspeech.io.dataset import TransformDataset\n", - "from deepspeech.io.reader import LoadInputsAndTargets\n", - "from deepspeech.utils.log import Log\n", - "\n", - "\n", - "logger = Log(__name__).getlog()\n", - "\n", - "\n", - "class BatchDataLoader():\n", - " def __init__(self,\n", - " json_file: str,\n", - " train_mode: bool,\n", - " sortagrad: bool=False,\n", - " batch_size: int=0,\n", - " maxlen_in: float=float('inf'),\n", - " maxlen_out: float=float('inf'),\n", - " minibatches: int=0,\n", - " mini_batch_size: int=1,\n", - " batch_count: str='auto',\n", - " batch_bins: int=0,\n", - " batch_frames_in: int=0,\n", - " batch_frames_out: int=0,\n", - " batch_frames_inout: int=0,\n", - " preprocess_conf=None,\n", - " n_iter_processes: int=1,\n", - " subsampling_factor: int=1,\n", - " num_encs: int=1):\n", - " self.json_file = json_file\n", - " self.train_mode = train_mode\n", - " self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n", - " self.batch_size = batch_size\n", - " self.maxlen_in = maxlen_in\n", - " self.maxlen_out = maxlen_out\n", - " self.batch_count = batch_count\n", - " self.batch_bins = batch_bins\n", - " self.batch_frames_in = batch_frames_in\n", - " self.batch_frames_out = batch_frames_out\n", - " self.batch_frames_inout = batch_frames_inout\n", - " self.subsampling_factor = subsampling_factor\n", - " self.num_encs = num_encs\n", - " self.preprocess_conf = preprocess_conf\n", - " self.n_iter_processes = n_iter_processes\n", - "\n", - " \n", - " # read json data\n", - " self.data_json = read_manifest(json_file)\n", - "\n", - " # make minibatch list (variable length)\n", - " self.minibaches = make_batchset(\n", - " self.data_json,\n", - " batch_size,\n", - " maxlen_in,\n", - " maxlen_out,\n", - " minibatches, # for debug\n", - " min_batch_size=mini_batch_size,\n", - " shortest_first=self.use_sortagrad,\n", - " count=batch_count,\n", - " batch_bins=batch_bins,\n", - " batch_frames_in=batch_frames_in,\n", - " batch_frames_out=batch_frames_out,\n", - " batch_frames_inout=batch_frames_inout,\n", - " iaxis=0,\n", - " oaxis=0, )\n", - "\n", - " # data reader\n", - " self.reader = LoadInputsAndTargets(\n", - " mode=\"asr\",\n", - " load_output=True,\n", - " preprocess_conf=preprocess_conf,\n", - " preprocess_args={\"train\":\n", - " train_mode}, # Switch the mode of preprocessing\n", - " )\n", - "\n", - " # Setup a converter\n", - " if num_encs == 1:\n", - " self.converter = CustomConverter(\n", - " subsampling_factor=subsampling_factor, dtype=np.float32)\n", - " else:\n", - " assert NotImplementedError(\"not impl CustomConverterMulEnc.\")\n", - "\n", - " # hack to make batchsize argument as 1\n", - " # actual bathsize is included in a list\n", - " # default collate function converts numpy array to pytorch tensor\n", - " # we used an empty collate function instead which returns list\n", - " self.dataset = TransformDataset(self.minibaches, \n", - " lambda data: self.converter([self.reader(data, return_uttid=True)]))\n", - " self.dataloader = DataLoader(\n", - " dataset=self.dataset,\n", - " batch_size=1,\n", - " shuffle=not use_sortagrad if train_mode else False,\n", - " collate_fn=lambda x: x[0],\n", - " num_workers=n_iter_processes, )\n", - "\n", - " def __repr__(self):\n", - " echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n", - " echo += f\"train_mode: {self.train_mode}, \"\n", - " echo += f\"sortagrad: {self.use_sortagrad}, \"\n", - " echo += f\"batch_size: {self.batch_size}, \"\n", - " echo += f\"maxlen_in: {self.maxlen_in}, \"\n", - " echo += f\"maxlen_out: {self.maxlen_out}, \"\n", - " echo += f\"batch_count: {self.batch_count}, \"\n", - " echo += f\"batch_bins: {self.batch_bins}, \"\n", - " echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n", - " echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n", - " echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n", - " echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n", - " echo += f\"num_encs: {self.num_encs}, \"\n", - " echo += f\"num_workers: {self.n_iter_processes}, \"\n", - " echo += f\"file: {self.json_file}\"\n", - " return echo\n", - " \n", - " def __len__(self):\n", - " return len(self.dataloader)\n", - " \n", - " def __iter__(self):\n", - " return self.dataloader.__iter__()\n", - " \n", - " def __call__(self):\n", - " return self.__iter__()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 231, - "id": "fcea3fd0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n", - "[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n", - "[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n" - ] - } - ], - "source": [ - "train = BatchDataLoader(dev_data, True, batch_size=20)" - ] - }, - { - "cell_type": "code", - "execution_count": 232, - "id": "e2a2c9a8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "278\n", - "['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n", - "<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n", - "278\n" - ] - } - ], - "source": [ - "print(len(train.dataloader))\n", - "print(dir(train.dataloader))\n", - "print(train)\n", - "print(len(train))" - ] - }, - { - "cell_type": "code", - "execution_count": 220, - "id": "a5ba7d6e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n", - "Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[-1.99415934, -1.80315673, -1.88801885, ..., 0.86933994, -0.59853148, 0.02596200],\n", - " [-1.95346808, -1.84891188, -2.17492867, ..., 0.83640492, -0.59853148, -0.11333394],\n", - " [-2.27899861, -2.21495342, -2.58480024, ..., 0.91874266, -0.59853148, -0.31453922],\n", - " ...,\n", - " [-2.64522028, -2.35221887, -2.91269732, ..., 1.48994756, -0.16100442, 0.36646330],\n", - " [-2.40107250, -2.21495342, -2.37986445, ..., 1.44072104, -0.13220564, 0.12656468],\n", - " [-2.15692472, -1.89466715, -2.25690317, ..., 1.31273174, -0.09620714, -0.15202725]],\n", - "\n", - " [[-0.28859532, -0.29033494, -0.86576819, ..., 1.37753224, -0.30570769, 0.25806731],\n", - " [-0.20149794, -0.17814466, -0.59891301, ..., 1.35188794, -0.30570769, -0.02964944],\n", - " [-0.34947991, -0.33597648, -0.96877253, ..., 1.38394332, -0.30570769, -0.38376236],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.44914246, -0.33902276, -0.78237975, ..., 1.38218808, 0.29214793, -0.16815147],\n", - " [-0.55490732, -0.41596055, -0.84425378, ..., 1.34530187, 0.25002354, -0.04004869],\n", - " [-0.83694696, -0.62112784, -1.07112527, ..., 1.19160914, 0.20789915, 0.37984371],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.24343657, -0.94188881, -1.41092563, ..., 0.96716309, 0.60345763, 0.15360183],\n", - " [-1.19466043, -0.80585432, -0.49723154, ..., 1.06735480, 0.60345763, 0.14511746],\n", - " [-0.94079566, -0.59330046, -0.40948665, ..., 0.82244170, 0.55614340, 0.28086722],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.21757117, 0.11361472, -0.33262897, ..., 0.76338506, -0.10711290, -0.57754958],\n", - " [-1.00205481, -0.61152041, -0.47124696, ..., 1.11897349, -0.10711290, 0.24931324],\n", - " [-1.03929281, -1.20336759, -1.16433656, ..., 0.88888687, -0.10711290, -0.04115745],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-1.25289667, -1.05046368, -0.82881606, ..., 1.23991334, 0.61702502, 0.05275881],\n", - " [-1.19659519, -0.78677225, -0.80407262, ..., 1.27644968, 0.61702502, -0.35079369],\n", - " [-1.49687004, -1.01750231, -0.82881606, ..., 1.29106426, 0.65006059, 0.17958963],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]]])\n", - "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n", - "Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[3098, 1595, 389, ..., -1 , -1 , -1 ],\n", - " [2603, 4832, 482, ..., -1 , -1 , -1 ],\n", - " [2796, 303, 269, ..., -1 , -1 , -1 ],\n", - " ...,\n", - " [3218, 3673, 206, ..., -1 , -1 , -1 ],\n", - " [2371, 4832, 4031, ..., -1 , -1 , -1 ],\n", - " [2570, 2433, 4285, ..., -1 , -1 , -1 ]])\n", - "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n" - ] - } - ], - "source": [ - "for batch in train:\n", - " utts, xs, ilens, ys, olens = batch\n", - " print(utts)\n", - " print(xs)\n", - " print(ilens)\n", - " print(ys)\n", - " print(olens)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3c974a1e", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/hack_api_test.ipynb b/.notebook/hack_api_test.ipynb deleted file mode 100644 index f653084e6..000000000 --- a/.notebook/hack_api_test.ipynb +++ /dev/null @@ -1,290 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "breeding-haven", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "appropriate-theta", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "LICENSE deepspeech examples\t\t requirements.txt tools\r\n", - "README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n", - "README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n" - ] - } - ], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "entire-bloom", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n", - "WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n", - "WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n", - "WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n" - ] - } - ], - "source": [ - "from deepspeech.modules import loss" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "governmental-aircraft", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "import paddle" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "proprietary-disaster", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " paddle.VarBase>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.repeat" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "first-diagram", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.size" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "intelligent-david", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.cat" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "bronze-tenant", - "metadata": {}, - "outputs": [], - "source": [ - "a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "balanced-bearing", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "7" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "extreme-republic", - "metadata": {}, - "outputs": [], - "source": [ - "def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n", - " nargs = len(args)\n", - " assert (nargs <= 1)\n", - " s = paddle.shape(xs)\n", - " if nargs == 1:\n", - " return s[args[0]]\n", - " else:\n", - " return s\n", - "\n", - "# logger.warn(\n", - "# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n", - "# )\n", - "paddle.Tensor.size = size" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "gross-addiction", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [7])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size(0)\n", - "a.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "adverse-dining", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [7])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "popular-potato", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb deleted file mode 100644 index 20882c1ae..000000000 --- a/.notebook/jit_infer.ipynb +++ /dev/null @@ -1,672 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-03-26 02:55:23,873 - WARNING - register user softmax to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,875 - WARNING - register user sigmoid to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,875 - WARNING - register user relu to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,876 - WARNING - override cat of paddle if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,876 - WARNING - override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,877 - WARNING - override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,877 - WARNING - override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "2021-03-26 02:55:23,878 - WARNING - register user view to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,878 - WARNING - register user view_as to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,879 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,880 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,880 - WARNING - register user fill_ to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,881 - WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,881 - WARNING - register user softmax to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,882 - WARNING - register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,882 - WARNING - register user relu to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,883 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n", - "2021-03-26 02:55:23,883 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "2021-03-26 02:55:23,884 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n", - "2021-03-26 02:55:23,884 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n" - ] - } - ], - "source": [ - "import os\n", - "import time\n", - "import argparse\n", - "import functools\n", - "import paddle\n", - "import numpy as np\n", - "\n", - "from deepspeech.utils.socket_server import warm_up_test\n", - "from deepspeech.utils.socket_server import AsrTCPServer\n", - "from deepspeech.utils.socket_server import AsrRequestHandler\n", - "\n", - "from deepspeech.training.cli import default_argument_parser\n", - "from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n", - "\n", - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "\n", - "from deepspeech.models.ds2 import DeepSpeech2Model\n", - "from deepspeech.models.ds2 import DeepSpeech2InferModel\n", - "from deepspeech.io.dataset import ManifestDataset\n", - "\n", - "\n", - "\n", - "from deepspeech.frontend.utility import read_manifest" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0.0\n", - "e7f28d6c0db54eb9c9a810612300b526687e56a6\n", - "OFF\n", - "OFF\n", - "commit: e7f28d6c0db54eb9c9a810612300b526687e56a6\n", - "None\n", - "0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " 'commit',\n", - " 'full_version',\n", - " 'istaged',\n", - " 'major',\n", - " 'minor',\n", - " 'mkl',\n", - " 'patch',\n", - " 'rc',\n", - " 'show',\n", - " 'with_mkl']" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(paddle.__version__)\n", - "print(paddle.version.commit)\n", - "print(paddle.version.with_mkl)\n", - "print(paddle.version.mkl())\n", - "print(paddle.version.show())\n", - "print(paddle.version.patch)\n", - "dir(paddle.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data:\n", - " augmentation_config: conf/augmentation.config\n", - " batch_size: 64\n", - " dev_manifest: data/manifest.dev\n", - " keep_transcription_text: False\n", - " max_duration: 27.0\n", - " max_freq: None\n", - " mean_std_filepath: examples/aishell/data/mean_std.npz\n", - " min_duration: 0.0\n", - " n_fft: None\n", - " num_workers: 0\n", - " random_seed: 0\n", - " shuffle_method: batch_shuffle\n", - " sortagrad: True\n", - " specgram_type: linear\n", - " stride_ms: 10.0\n", - " target_dB: -20\n", - " target_sample_rate: 16000\n", - " test_manifest: examples/aishell/data/manifest.test\n", - " train_manifest: data/manifest.train\n", - " use_dB_normalization: True\n", - " vocab_filepath: examples/aishell/data/vocab.txt\n", - " window_ms: 20.0\n", - "decoding:\n", - " alpha: 2.6\n", - " batch_size: 128\n", - " beam_size: 300\n", - " beta: 5.0\n", - " cutoff_prob: 0.99\n", - " cutoff_top_n: 40\n", - " decoding_method: ctc_beam_search\n", - " error_rate_type: cer\n", - " lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n", - " num_proc_bsearch: 10\n", - "model:\n", - " num_conv_layers: 2\n", - " num_rnn_layers: 3\n", - " rnn_layer_size: 1024\n", - " share_rnn_weights: False\n", - " use_gru: True\n", - "training:\n", - " global_grad_clip: 5.0\n", - " lr: 0.0005\n", - " lr_decay: 0.83\n", - " n_epoch: 30\n", - " weight_decay: 1e-06\n", - "----------- Configuration Arguments -----------\n", - "checkpoint_path: examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725\n", - "config: examples/aishell/conf/deepspeech2.yaml\n", - "device: gpu\n", - "dump_config: None\n", - "export_path: None\n", - "host_ip: localhost\n", - "host_port: 8086\n", - "model_dir: None\n", - "model_file: examples/aishell/jit.model.pdmodel\n", - "nprocs: 1\n", - "opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n", - "output: None\n", - "params_file: examples/aishell/jit.model.pdiparams\n", - "speech_save_dir: demo_cache\n", - "use_gpu: False\n", - "warmup_manifest: examples/aishell/data/manifest.test\n", - "------------------------------------------------\n" - ] - } - ], - "source": [ - "parser = default_argument_parser()\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "add_arg('host_ip', str,\n", - " 'localhost',\n", - " \"Server's IP address.\")\n", - "add_arg('host_port', int, 8086, \"Server's IP port.\")\n", - "add_arg('speech_save_dir', str,\n", - " 'demo_cache',\n", - " \"Directory to save demo audios.\")\n", - "add_arg('warmup_manifest', \n", - " str, \n", - " \"examples/aishell/data/manifest.test\", \n", - " \"Filepath of manifest to warm up.\")\n", - "add_arg(\n", - " \"--model_file\",\n", - " type=str,\n", - " default=\"examples/aishell/jit.model.pdmodel\",\n", - " help=\"Model filename, Specify this when your model is a combined model.\"\n", - ")\n", - "add_arg(\n", - " \"--params_file\",\n", - " type=str,\n", - " default=\"examples/aishell/jit.model.pdiparams\",\n", - " help=\n", - " \"Parameter filename, Specify this when your model is a combined model.\"\n", - ")\n", - "add_arg(\n", - " \"--model_dir\",\n", - " type=str,\n", - " default=None,\n", - " help=\n", - " \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n", - ")\n", - "add_arg(\"--use_gpu\",type=bool,default=False, help=\"Whether use gpu.\")\n", - "\n", - "\n", - "args = parser.parse_args(\n", - " \"--checkpoint_path examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split()\n", - ")\n", - "\n", - "\n", - "config = get_cfg_defaults()\n", - "if args.config:\n", - " config.merge_from_file(args.config)\n", - "if args.opts:\n", - " config.merge_from_list(args.opts)\n", - "config.freeze()\n", - "print(config)\n", - "\n", - "args.warmup_manifest = config.data.test_manifest\n", - "\n", - "print_arguments(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = ManifestDataset(\n", - " config.data.test_manifest,\n", - " config.data.unit_type,\n", - " config.data.vocab_filepath,\n", - " config.data.mean_std_filepath,\n", - " augmentation_config=\"{}\",\n", - " max_duration=config.data.max_duration,\n", - " min_duration=config.data.min_duration,\n", - " stride_ms=config.data.stride_ms,\n", - " window_ms=config.data.window_ms,\n", - " n_fft=config.data.n_fft,\n", - " max_freq=config.data.max_freq,\n", - " target_sample_rate=config.data.target_sample_rate,\n", - " specgram_type=config.data.specgram_type,\n", - " feat_dim=config.data.feat_dim,\n", - " delta_delta=config.data.delat_delta,\n", - " use_dB_normalization=config.data.use_dB_normalization,\n", - " target_dB=config.data.target_dB,\n", - " random_seed=config.data.random_seed,\n", - " keep_transcription_text=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-03-26 02:55:57,930 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725.pdparams\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "layer summary:\n", - "encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n", - "encoder.conv.conv_in.bn.weight|[32]|32\n", - "encoder.conv.conv_in.bn.bias|[32]|32\n", - "encoder.conv.conv_in.bn._mean|[32]|32\n", - "encoder.conv.conv_in.bn._variance|[32]|32\n", - "encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n", - "encoder.conv.conv_stack.0.bn.weight|[32]|32\n", - "encoder.conv.conv_stack.0.bn.bias|[32]|32\n", - "encoder.conv.conv_stack.0.bn._mean|[32]|32\n", - "encoder.conv.conv_stack.0.bn._variance|[32]|32\n", - "encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n", - "encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n", - "encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n", - "decoder.ctc_lo.weight|[2048, 4300]|8806400\n", - "decoder.ctc_lo.bias|[4300]|4300\n", - "layer has 66 parameters, 80148012 elements.\n" - ] - } - ], - "source": [ - "model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n", - " args.checkpoint_path)\n", - "model.eval()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "examples/aishell/jit.model.pdmodel\n", - "examples/aishell/jit.model.pdiparams\n", - "0\n", - "False\n" - ] - } - ], - "source": [ - "\n", - "from paddle.inference import Config\n", - "from paddle.inference import PrecisionType\n", - "from paddle.inference import create_predictor\n", - "\n", - "args.use_gpu=False\n", - "paddle.set_device('cpu')\n", - "\n", - "def init_predictor(args):\n", - " if args.model_dir is not None:\n", - " config = Config(args.model_dir)\n", - " else:\n", - " config = Config(args.model_file, args.params_file)\n", - "\n", - " if args.use_gpu:\n", - " config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n", - "# config.enable_tensorrt_engine(precision_mode=PrecisionType.Float32,\n", - "# use_calib_mode=True) # 开启TensorRT预测,精度为fp32,开启int8离线量化\n", - " else:\n", - " # If not specific mkldnn, you can set the blas thread.\n", - " # The thread num should not be greater than the number of cores in the CPU.\n", - " config.set_cpu_math_library_num_threads(1)\n", - " config.enable_mkldnn()\n", - " \n", - " config.enable_memory_optim()\n", - " config.switch_ir_optim(True)\n", - " \n", - " print(config.model_dir())\n", - " print(config.prog_file())\n", - " print(config.params_file())\n", - " print(config.gpu_device_id())\n", - " print(args.use_gpu)\n", - " predictor = create_predictor(config)\n", - " return predictor\n", - "\n", - "def run(predictor, audio, audio_len):\n", - " # copy img data to input tensor\n", - " input_names = predictor.get_input_names()\n", - " for i, name in enumerate(input_names):\n", - " print(\"input:\", i, name)\n", - " \n", - " audio_tensor = predictor.get_input_handle('audio')\n", - " audio_tensor.reshape(audio.shape)\n", - " audio_tensor.copy_from_cpu(audio.copy())\n", - " \n", - " audiolen_tensor = predictor.get_input_handle('audio_len')\n", - " audiolen_tensor.reshape(audio_len.shape)\n", - " audiolen_tensor.copy_from_cpu(audio_len.copy())\n", - "\n", - " output_names = predictor.get_output_names()\n", - " for i, name in enumerate(output_names):\n", - " print(\"output:\", i, name)\n", - "\n", - " # do the inference\n", - " predictor.run()\n", - "\n", - " results = []\n", - " # get out data from output tensor\n", - " output_names = predictor.get_output_names()\n", - " for i, name in enumerate(output_names):\n", - " output_tensor = predictor.get_output_handle(name)\n", - " output_data = output_tensor.copy_to_cpu()\n", - " results.append(output_data)\n", - "\n", - " return results\n", - "\n", - "\n", - "predictor = init_predictor(args)\n", - "\n", - "def file_to_transcript(filename):\n", - " print(filename)\n", - " feature = dataset.process_utterance(filename, \"\")\n", - " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", - " audio_len = feature[0].shape[1]\n", - " audio_len = np.array([audio_len]).astype('int64') # [1]\n", - " \n", - " \n", - " i_probs = run(predictor, audio, audio_len)\n", - " print('jit:', i_probs[0], type(i_probs[0]))\n", - " \n", - " audio = paddle.to_tensor(audio)\n", - " audio_len = paddle.to_tensor(audio_len)\n", - " print(audio.shape)\n", - " print(audio_len.shape)\n", - " \n", - " #eouts, eouts_len = model.encoder(audio, audio_len)\n", - " #probs = model.decoder.softmax(eouts)\n", - " probs = model.forward(audio, audio_len)\n", - " print('paddle:', probs.numpy())\n", - " \n", - " flag = np.allclose(i_probs[0], probs.numpy())\n", - " print(flag)\n", - " \n", - " return probs\n", - "\n", - "# result_transcript = model.decode(\n", - "# audio,\n", - "# audio_len,\n", - "# vocab_list=dataset.vocab_list,\n", - "# decoding_method=config.decoding.decoding_method,\n", - "# lang_model_path=config.decoding.lang_model_path,\n", - "# beam_alpha=config.decoding.alpha,\n", - "# beam_beta=config.decoding.beta,\n", - "# beam_size=config.decoding.beam_size,\n", - "# cutoff_prob=config.decoding.cutoff_prob,\n", - "# cutoff_top_n=config.decoding.cutoff_top_n,\n", - "# num_processes=config.decoding.num_proc_bsearch)\n", - "# return result_transcript[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n", - "input: 0 audio\n", - "input: 1 audio_len\n", - "output: 0 tmp_75\n", - "jit: [[[8.91786298e-12 4.45648032e-12 3.67572750e-09 ... 8.91767563e-12\n", - " 8.91573707e-12 4.64317296e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638127e-17 7.61802427e-16 2.93265812e-14 ... 1.24633371e-17\n", - " 1.24587264e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676260e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89334696e-13 1.66754856e-11 1.42900388e-11 ... 3.89329492e-13\n", - " 3.89252270e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]] \n", - "[1, 161, 522]\n", - "[1]\n", - "paddle: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n", - " 8.91577090e-12 4.64319072e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n", - " 1.24587735e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n", - " 3.89253761e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]]\n", - "False\n" - ] - } - ], - "source": [ - "manifest = read_manifest(args.warmup_manifest)\n", - "\n", - "for idx, sample in enumerate(manifest[:1]):\n", - " print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n", - " start_time = time.time()\n", - " transcript = file_to_transcript(sample['audio_filepath'])\n", - " finish_time = time.time()\n", - "# print(\"Response Time: %f, Transcript: %s\" %\n", - "# (finish_time - start_time, transcript))\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 161, 522) (1,)\n", - "input: 0 audio\n", - "input: 1 audio_len\n", - "output: 0 tmp_75\n", - "jit: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n", - " 8.91577090e-12 4.64319072e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n", - " 1.24587735e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n", - " 3.89253761e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]]\n" - ] - } - ], - "source": [ - "def test(filename):\n", - " feature = dataset.process_utterance(filename, \"\")\n", - " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", - " audio_len = feature[0].shape[1]\n", - " audio_len = np.array([audio_len]).astype('int64') # [1]\n", - " \n", - " print(audio.shape, audio_len.shape)\n", - "\n", - " i_probs = run(predictor, audio, audio_len)\n", - " print('jit:', i_probs[0])\n", - " return i_probs\n", - " \n", - "probs = test(sample['audio_filepath'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/.notebook/layer_norm_test.ipynb b/.notebook/layer_norm_test.ipynb deleted file mode 100644 index eac3566ff..000000000 --- a/.notebook/layer_norm_test.ipynb +++ /dev/null @@ -1,229 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 32, - "id": "academic-surname", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from paddle import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "fundamental-treasure", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n", - "Parameter containing:\n", - "Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n" - ] - } - ], - "source": [ - "L = nn.LayerNorm(256, epsilon=1e-12)\n", - "for p in L.parameters():\n", - " print(p)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "consolidated-elephant", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "moderate-noise", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 51, 256)\n", - "print(x.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "cooked-progressive", - "metadata": {}, - "outputs": [], - "source": [ - "y = L(paddle.to_tensor(x, dtype='float32'))" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "optimum-milwaukee", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "viral-indian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1.], requires_grad=True)\n", - "Parameter containing:\n", - "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " requires_grad=True)\n" - ] - } - ], - "source": [ - "TL = torch.nn.LayerNorm(256, eps=1e-12)\n", - "for p in TL.parameters():\n", - " print(p)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "skilled-vietnamese", - "metadata": {}, - "outputs": [], - "source": [ - "ty = TL(torch.tensor(x, dtype=torch.float32))" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "incorrect-allah", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(y.numpy(), ty.detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "prostate-cameroon", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "governmental-surge", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = np.random.randn(2, 256)\n", - "y = L(paddle.to_tensor(x, dtype='float32'))\n", - "ty = TL(torch.tensor(x, dtype=torch.float32))\n", - "np.allclose(y.numpy(), ty.detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "confidential-jacket", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/mask_and_masked_fill_test.ipynb b/.notebook/mask_and_masked_fill_test.ipynb deleted file mode 100644 index 265ec536b..000000000 --- a/.notebook/mask_and_masked_fill_test.ipynb +++ /dev/null @@ -1,449 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "primary-organic", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "stopped-semester", - "metadata": {}, - "outputs": [], - "source": [ - "def mask_finished_scores(score: torch.Tensor,\n", - " flag: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " If a sequence is finished, we only allow one alive branch. This function\n", - " aims to give one branch a zero score and the rest -inf score.\n", - " Args:\n", - " score (torch.Tensor): A real value array with shape\n", - " (batch_size * beam_size, beam_size).\n", - " flag (torch.Tensor): A bool array with shape\n", - " (batch_size * beam_size, 1).\n", - " Returns:\n", - " torch.Tensor: (batch_size * beam_size, beam_size).\n", - " \"\"\"\n", - " beam_size = score.size(-1)\n", - " zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n", - " if beam_size > 1:\n", - " unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n", - " dim=1)\n", - " finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n", - " dim=1)\n", - " else:\n", - " unfinished = zero_mask\n", - " finished = flag\n", - " print(unfinished)\n", - " print(finished)\n", - " score.masked_fill_(unfinished, -float('inf'))\n", - " score.masked_fill_(finished, 0)\n", - " return score" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "agreed-portuguese", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ True],\n", - " [False]])\n", - "tensor([[-0.8841, 0.7381, -0.9986],\n", - " [ 0.2675, -0.7971, 0.3798]])\n", - "tensor([[ True, True],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "score = torch.randn((2, 3))\n", - "flag = torch.ones((2, 1), dtype=torch.bool)\n", - "flag[1] = False\n", - "print(flag)\n", - "print(score)\n", - "print(flag.repeat([1, 2]))" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "clean-aspect", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[False, True, True],\n", - " [False, False, False]])\n", - "tensor([[ True, False, False],\n", - " [False, False, False]])\n", - "tensor([[ 0.0000, -inf, -inf],\n", - " [ 0.2675, -0.7971, 0.3798]])\n", - "tensor([[ 0.0000, -inf, -inf],\n", - " [ 0.2675, -0.7971, 0.3798]])\n" - ] - } - ], - "source": [ - "r = mask_finished_scores(score, flag)\n", - "print(r)\n", - "print(score)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "thrown-airline", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True ],\n", - " [False]])\n", - "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "import paddle\n", - "\n", - "score = paddle.randn((2, 3))\n", - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag)\n", - "print(score)\n", - "print(flag.tile([1, 2]))" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "internal-patent", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[False, True , True ],\n", - " [False, False, False]])\n", - "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , False, False],\n", - " [False, False, False]])\n", - "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 0. , -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 0. , -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n" - ] - } - ], - "source": [ - "paddle.bool = 'bool'\n", - "\n", - "def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", - " print(xs)\n", - " trues = paddle.ones_like(xs) * value\n", - " assert xs.shape == mask.shape\n", - " xs = paddle.where(mask, trues, xs)\n", - " return xs\n", - "\n", - "def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", - " print('x', xs)\n", - " trues = paddle.ones_like(xs) * value\n", - " assert xs.shape == mask.shape\n", - " ret = paddle.where(mask, trues, xs)\n", - " print('2', xs)\n", - " paddle.assign(ret, output=xs)\n", - " print('3', xs)\n", - "\n", - "paddle.Tensor.masked_fill = masked_fill\n", - "paddle.Tensor.masked_fill_ = masked_fill_\n", - "\n", - "def mask_finished_scores_pd(score: paddle.Tensor,\n", - " flag: paddle.Tensor) -> paddle.Tensor:\n", - " \"\"\"\n", - " If a sequence is finished, we only allow one alive branch. This function\n", - " aims to give one branch a zero score and the rest -inf score.\n", - " Args:\n", - " score (torch.Tensor): A real value array with shape\n", - " (batch_size * beam_size, beam_size).\n", - " flag (torch.Tensor): A bool array with shape\n", - " (batch_size * beam_size, 1).\n", - " Returns:\n", - " torch.Tensor: (batch_size * beam_size, beam_size).\n", - " \"\"\"\n", - " beam_size = score.shape[-1]\n", - " zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n", - " if beam_size > 1:\n", - " unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n", - " axis=1)\n", - " finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n", - " axis=1)\n", - " else:\n", - " unfinished = zero_mask\n", - " finished = flag\n", - " print(unfinished)\n", - " print(finished)\n", - " \n", - " #score.masked_fill_(unfinished, -float('inf'))\n", - " #score.masked_fill_(finished, 0)\n", - "# infs = paddle.ones_like(score) * -float('inf')\n", - "# score = paddle.where(unfinished, infs, score)\n", - "# score = paddle.where(finished, paddle.zeros_like(score), score)\n", - "\n", - "# score = score.masked_fill(unfinished, -float('inf'))\n", - "# score = score.masked_fill(finished, 0)\n", - " score.masked_fill_(unfinished, -float('inf'))\n", - " score.masked_fill_(finished, 0)\n", - " return score\n", - "\n", - "r = mask_finished_scores_pd(score, flag)\n", - "print(r)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "vocal-prime", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "score.value" - ] - }, - { - "cell_type": "code", - "execution_count": 71, - "id": "bacterial-adolescent", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Union, Any" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "id": "absent-fiber", - "metadata": {}, - "outputs": [], - "source": [ - "def repeat(xs : paddle.Tensor, *size: Any):\n", - " print(size)\n", - " return paddle.tile(xs, size)\n", - "paddle.Tensor.repeat = repeat" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "id": "material-harbor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag.repeat(1, 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "id": "acute-brighton", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [1]), 2)\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag.repeat(paddle.to_tensor(1), 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "id": "european-rugby", - "metadata": {}, - "outputs": [], - "source": [ - "def size(xs, *args: int):\n", - " nargs = len(args)\n", - " s = paddle.shape(xs)\n", - " assert(nargs <= 1)\n", - " if nargs == 1:\n", - " return s[args[0]]\n", - " else:\n", - " return s\n", - "paddle.Tensor.size = size" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "id": "moral-special", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [2, 1])" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "id": "ahead-coach", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [1])" - ] - }, - "execution_count": 87, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "id": "incomplete-fitness", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [2])" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size(0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "upset-connectivity", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/position_embeding_check.ipynb b/.notebook/position_embeding_check.ipynb deleted file mode 100644 index d4b9098d9..000000000 --- a/.notebook/position_embeding_check.ipynb +++ /dev/null @@ -1,231 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "id": "designing-borough", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "import torch\n", - "import math\n", - "import numpy as np\n", - "\n", - "max_len=100\n", - "d_model=256\n", - "\n", - "pe = torch.zeros(max_len, d_model)\n", - "position = torch.arange(0, max_len,\n", - " dtype=torch.float32).unsqueeze(1)\n", - "toruch_position = position\n", - "div_term = torch.exp(\n", - " torch.arange(0, d_model, 2, dtype=torch.float32) *\n", - " -(math.log(10000.0) / d_model))\n", - "tourch_div_term = div_term.cpu().detach().numpy()\n", - "\n", - "\n", - "\n", - "torhc_sin = torch.sin(position * div_term)\n", - "torhc_cos = torch.cos(position * div_term)\n", - "print(torhc_sin.cpu().detach().numpy())\n", - "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", - "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", - "print(np.allclose(np_sin, torhc_sin.cpu().detach().numpy()))\n", - "print(np.allclose(np_cos, torhc_cos.cpu().detach().numpy()))\n", - "pe[:, 0::2] = torhc_sin\n", - "pe[:, 1::2] = torhc_cos\n", - "tourch_pe = pe.cpu().detach().numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "swiss-referral", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "False\n", - "False\n", - "False\n", - "False\n", - "[[ 1. 1. 1. ... 1. 1.\n", - " 1. ]\n", - " [ 0.5403023 0.59737533 0.6479059 ... 1. 1.\n", - " 1. ]\n", - " [-0.41614684 -0.28628543 -0.1604359 ... 0.99999994 1.\n", - " 1. ]\n", - " ...\n", - " [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.99993724\n", - " 0.9999457 ]\n", - " [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n", - " 0.99994457]\n", - " [ 0.03982088 -0.52298605 -0.6157435 ... 0.99992454 0.9999347\n", - " 0.99994344]]\n", - "----\n", - "[[ 1. 1. 1. ... 1. 1.\n", - " 1. ]\n", - " [ 0.54030234 0.59737533 0.6479059 ... 1. 1.\n", - " 1. ]\n", - " [-0.41614684 -0.28628543 -0.1604359 ... 1. 1.\n", - " 1. ]\n", - " ...\n", - " [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.9999373\n", - " 0.9999457 ]\n", - " [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n", - " 0.99994457]\n", - " [ 0.03982088 -0.5229861 -0.6157435 ... 0.99992454 0.9999347\n", - " 0.99994344]]\n", - ")))))))\n", - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n", - "----\n", - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n" - ] - } - ], - "source": [ - "import paddle\n", - "paddle.set_device('cpu')\n", - "ppe = paddle.zeros((max_len, d_model), dtype='float32')\n", - "position = paddle.arange(0, max_len,\n", - " dtype='float32').unsqueeze(1)\n", - "print(np.allclose(position.numpy(), toruch_position))\n", - "div_term = paddle.exp(\n", - " paddle.arange(0, d_model, 2, dtype='float32') *\n", - " -(math.log(10000.0) / d_model))\n", - "print(np.allclose(div_term.numpy(), tourch_div_term))\n", - "\n", - "\n", - "\n", - "p_sin = paddle.sin(position * div_term)\n", - "p_cos = paddle.cos(position * div_term)\n", - "print(np.allclose(np_sin, p_sin.numpy(), rtol=1.e-6, atol=0))\n", - "print(np.allclose(np_cos, p_cos.numpy(), rtol=1.e-6, atol=0))\n", - "ppe[:, 0::2] = p_sin\n", - "ppe[:, 1::2] = p_cos\n", - "print(np.allclose(p_sin.numpy(), torhc_sin.cpu().detach().numpy()))\n", - "print(np.allclose(p_cos.numpy(), torhc_cos.cpu().detach().numpy()))\n", - "print(p_cos.numpy())\n", - "print(\"----\")\n", - "print(torhc_cos.cpu().detach().numpy())\n", - "print(\")))))))\")\n", - "print(p_sin.numpy())\n", - "print(\"----\")\n", - "print(torhc_sin.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "integrated-boards", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n" - ] - } - ], - "source": [ - "print(np.allclose(ppe.numpy(), pe.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "flying-reserve", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "revised-divide", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/python_test.ipynb b/.notebook/python_test.ipynb deleted file mode 100644 index 819d4c48f..000000000 --- a/.notebook/python_test.ipynb +++ /dev/null @@ -1,1680 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-lender", - "metadata": {}, - "outputs": [], - "source": [ - "eng=\"one minute a voice said and the time buzzer sounded\"\n", - "chn=\"可控是病毒武器最基本的要求\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ruled-kuwait", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "o\n", - "n\n", - "e\n", - " \n", - "m\n", - "i\n", - "n\n", - "u\n", - "t\n", - "e\n", - " \n", - "a\n", - " \n", - "v\n", - "o\n", - "i\n", - "c\n", - "e\n", - " \n", - "s\n", - "a\n", - "i\n", - "d\n", - " \n", - "a\n", - "n\n", - "d\n", - " \n", - "t\n", - "h\n", - "e\n", - " \n", - "t\n", - "i\n", - "m\n", - "e\n", - " \n", - "b\n", - "u\n", - "z\n", - "z\n", - "e\n", - "r\n", - " \n", - "s\n", - "o\n", - "u\n", - "n\n", - "d\n", - "e\n", - "d\n" - ] - } - ], - "source": [ - "for char in eng:\n", - " print(char)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "passive-petite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "可\n", - "控\n", - "是\n", - "病\n", - "毒\n", - "武\n", - "器\n", - "最\n", - "基\n", - "本\n", - "的\n", - "要\n", - "求\n" - ] - } - ], - "source": [ - "for char in chn:\n", - " print(char)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "olympic-realtor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "one\n", - "minute\n", - "a\n", - "voice\n", - "said\n", - "and\n", - "the\n", - "time\n", - "buzzer\n", - "sounded\n" - ] - } - ], - "source": [ - "for word in eng.split():\n", - " print(word)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "induced-enhancement", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "可控是病毒武器最基本的要求\n" - ] - } - ], - "source": [ - "for word in chn.split():\n", - " print(word)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "lovely-bottle", - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'StringIO'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'StringIO'" - ] - } - ], - "source": [ - "import StringIO" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "interested-cardiff", - "metadata": {}, - "outputs": [], - "source": [ - "from io import StringIO" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "portable-ivory", - "metadata": {}, - "outputs": [], - "source": [ - "inputs = StringIO()" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "compatible-destination", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "federal-margin", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n", - "\n" - ] - } - ], - "source": [ - "print(inputs.getvalue())" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "consecutive-entity", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "desirable-anxiety", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n", - "nor is mister quilter's manner less interesting than his matter\n", - "\n" - ] - } - ], - "source": [ - "print(inputs.getvalue())" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "employed-schedule", - "metadata": {}, - "outputs": [], - "source": [ - "import tempfile" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "unlikely-honduras", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__class__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__ne__', '__new__', '__next__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_checkClosed', '_checkReadable', '_checkSeekable', '_checkWritable', '_dealloc_warn', '_finalizing', 'close', 'closed', 'detach', 'fileno', 'flush', 'isatty', 'mode', 'name', 'peek', 'raw', 'read', 'read1', 'readable', 'readinto', 'readinto1', 'readline', 'readlines', 'seek', 'seekable', 'tell', 'truncate', 'writable', 'write', 'writelines']\n", - "57\n" - ] - } - ], - "source": [ - "with tempfile.TemporaryFile() as fp:\n", - " print(dir(fp))\n", - " print(fp.name)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "needed-trail", - "metadata": {}, - "outputs": [], - "source": [ - "a = tempfile.mkstemp(suffix=None, prefix='test', dir=None, text=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "hazardous-choir", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__add__', '__class__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'count', 'index']\n" - ] - } - ], - "source": [ - "print(dir(a))" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "front-sauce", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(57, '/tmp/test27smzbzc')\n" - ] - } - ], - "source": [ - "print(a)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "shared-wages", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "print(a.index)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "charged-carnival", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_closer', 'close', 'delete', 'file', 'name']\n", - "/tmp/tmpfjn7mygy\n" - ] - } - ], - "source": [ - "fp= tempfile.NamedTemporaryFile(mode='w', delete=False)\n", - "print(dir(fp))\n", - "print(fp.name)\n", - "fp.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "religious-terror", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/tmp/tmpfjn7mygy\n" - ] - } - ], - "source": [ - "import os\n", - "os.path.exists(fp.name)\n", - "print(fp.name)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "communist-gospel", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fp.write" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "simplified-clarity", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'example'" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s='/home/ubuntu/python/example.py'\n", - "os.path.splitext(os.path.basename(s))[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "popular-genius", - "metadata": {}, - "outputs": [], - "source": [ - "from collections import Counter" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "studied-burner", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('hello', 1), ('world', 1)])\n" - ] - } - ], - "source": [ - "counter = Counter()\n", - "counter.update([\"hello\"])\n", - "counter.update([\"world\"])\n", - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "mineral-ceremony", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('h', 1), ('e', 1), ('l', 3), ('o', 2), ('w', 1), ('r', 1), ('d', 1)])\n" - ] - } - ], - "source": [ - "counter = Counter()\n", - "counter.update(\"hello\")\n", - "counter.update(\"world\")\n", - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "nonprofit-freedom", - "metadata": {}, - "outputs": [], - "source": [ - "counter.update(list(\"hello\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "extended-methodology", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('h', 2), ('e', 2), ('l', 5), ('o', 3), ('w', 1), ('r', 1), ('d', 1)])\n" - ] - } - ], - "source": [ - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "grand-benjamin", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['h', 'e', 'l', 'l', 'o']" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(\"hello\")" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "marine-fundamentals", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{}\n" - ] - } - ], - "source": [ - "from io import StringIO\n", - "a = StringIO(initial_value='{}', newline='')\n", - "print(a.read())" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "suitable-charlotte", - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "expected str, bytes or os.PathLike object, not _io.StringIO", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not _io.StringIO" - ] - } - ], - "source": [ - "with io.open(a) as f:\n", - " print(f.read())" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "institutional-configuration", - "metadata": {}, - "outputs": [], - "source": [ - "io.open?" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "pregnant-modem", - "metadata": {}, - "outputs": [], - "source": [ - "def get_default_args(fn):\n", - " if fn is None:\n", - " return {}\n", - "\n", - " signature = inspect.signature(fn)\n", - " return {\n", - " k: v.default\n", - " for k, v in signature.parameters.items()\n", - " if v.default is not inspect.Parameter.empty\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "first-release", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'inspect' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_default_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mget_default_args\u001b[0;34m(fn)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m return {\n\u001b[1;32m 7\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'inspect' is not defined" - ] - } - ], - "source": [ - "get_default_args(io.open)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "convertible-roulette", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: sox in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (1.4.1)\n", - "Requirement already satisfied: numpy>=1.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from sox) (1.20.1)\n", - "Requirement already satisfied: librosa in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (0.8.0)\n", - "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.24.1)\n", - "Requirement already satisfied: numba>=0.43.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.52.0)\n", - "Requirement already satisfied: pooch>=1.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.3.0)\n", - "Requirement already satisfied: scipy>=1.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.2.1)\n", - "Requirement already satisfied: numpy>=1.15.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.20.1)\n", - "Requirement already satisfied: decorator>=3.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (4.4.2)\n", - "Requirement already satisfied: resampy>=0.2.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.2.2)\n", - "Requirement already satisfied: audioread>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (2.1.9)\n", - "Requirement already satisfied: soundfile>=0.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.9.0.post1)\n", - "Requirement already satisfied: joblib>=0.14 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.0.1)\n", - "Requirement already satisfied: setuptools in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (51.0.0)\n", - "Requirement already satisfied: llvmlite<0.36,>=0.35.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (0.35.0)\n", - "Requirement already satisfied: appdirs in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (1.4.4)\n", - "Requirement already satisfied: packaging in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (20.9)\n", - "Requirement already satisfied: requests in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (2.25.1)\n", - "Requirement already satisfied: six>=1.3 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from resampy>=0.2.2->librosa) (1.15.0)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (2.1.0)\n", - "Requirement already satisfied: cffi>=0.6 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from soundfile>=0.9.0->librosa) (1.14.4)\n", - "Requirement already satisfied: pycparser in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cffi>=0.6->soundfile>=0.9.0->librosa) (2.20)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from packaging->pooch>=1.0->librosa) (2.4.7)\n", - "Requirement already satisfied: idna<3,>=2.5 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2.10)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2020.12.5)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (4.0.0)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (1.26.3)\n" - ] - } - ], - "source": [ - "!pip install sox\n", - "!pip install librosa" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "cutting-fleece", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import sox\n", - "tfm = sox.Transformer()\n", - "sample_rate = 44100\n", - "y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)\n", - "print(y.dtype.type)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "historical-diving", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.06264832 0.12505052 ... -0.18696144 -0.12505052\n", - " -0.06264832]\n" - ] - } - ], - "source": [ - "output_array = tfm.build_array(input_array=y, sample_rate_in=sample_rate)\n", - "print(output_array)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "similar-spice", - "metadata": {}, - "outputs": [], - "source": [ - "tfm.build_array?" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "grand-influence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['8svx', 'aif', 'aifc', 'aiff', 'aiffc', 'al', 'amb', 'amr-nb', 'amr-wb', 'anb', 'au', 'avr', 'awb', 'caf', 'cdda', 'cdr', 'cvs', 'cvsd', 'cvu', 'dat', 'dvms', 'f32', 'f4', 'f64', 'f8', 'fap', 'flac', 'fssd', 'gsm', 'gsrt', 'hcom', 'htk', 'ima', 'ircam', 'la', 'lpc', 'lpc10', 'lu', 'mat', 'mat4', 'mat5', 'maud', 'nist', 'ogg', 'paf', 'prc', 'pvf', 'raw', 's1', 's16', 's2', 's24', 's3', 's32', 's4', 's8', 'sb', 'sd2', 'sds', 'sf', 'sl', 'sln', 'smp', 'snd', 'sndfile', 'sndr', 'sndt', 'sou', 'sox', 'sph', 'sw', 'txw', 'u1', 'u16', 'u2', 'u24', 'u3', 'u32', 'u4', 'u8', 'ub', 'ul', 'uw', 'vms', 'voc', 'vorbis', 'vox', 'w64', 'wav', 'wavpcm', 'wv', 'wve', 'xa', 'xi']\n" - ] - } - ], - "source": [ - "print(sox.core._get_valid_formats())" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "wireless-hypothetical", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n", - "(59471,)\n", - "16000\n", - "(54065,)\n", - "1.0999907518727459\n" - ] - } - ], - "source": [ - "import soundfile as sf\n", - "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/dev/S0724/BAC009S0724W0190.wav'\n", - "samples, sr = sf.read(wav)\n", - "print(samples.dtype)\n", - "print(samples.shape)\n", - "print(sr)\n", - "tfm = sox.Transformer()\n", - "tfm.speed(1.1)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "output_array.dtype\n", - "print(output_array.shape)\n", - "print(len(samples)/len(output_array))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "designed-fluid", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import IPython.display as ipd\n", - "ipd.Audio(wav) # load a local WAV file" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "cultural-friendship", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(1.0)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "fossil-lotus", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(1.1)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "constitutional-poker", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(0.9)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "threaded-strap", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "66078\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8K0lEQVR4nO2dd3hUZfbHvycdQoAEQpEWmlQVJICoKApqABdcF8u6KlbUXX+77rqrIFZs7Lr2si5WXHXtrigI0myoSFB67xDpoYQEUs/vj7kTJpM7M/fO7XPP53nmye33ZObe97zveU8hZoYgCILgX5KcFkAQBEFwFlEEgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPscURUBEBUS0log2ENF4lf1/IaJVRLSMiOYSUYeQfWOJaL3yGWuGPIIgCIJ2yGgcARElA1gH4DwAOwAsAvBbZl4Vcsw5ABYycxkR3QJgCDNfRkQ5AAoB5ANgAIsB9GPmA4aEEgRBEDRjxohgAIANzLyJmSsAvANgdOgBzDyfmcuU1R8AtFWWLwAwm5mLlcZ/NoACE2QSBEEQNJJiwjXaANgesr4DwMAox18P4PMo57aJdcPmzZtzXl6ePikFQRB8zuLFi/cxc274djMUgWaI6EoEzEBnx3HuOADjAKB9+/YoLCw0WTpBEITEhoi2qm03wzRUBKBdyHpbZVu4AMMATAQwipnL9ZwLAMw8hZnzmTk/N7eeQhMEQRDixAxFsAhAVyLqSERpAC4HMC30ACLqC+DfCCiBPSG7ZgE4n4iyiSgbwPnKNkEQBMEmDJuGmLmKiG5FoAFPBvAqM68kokkACpl5GoDHADQC8D4RAcA2Zh7FzMVE9CACygQAJjFzsVGZBEEQBO0Ydh91gvz8fJY5AkEQBH0Q0WJmzg/fLpHFgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgif42/tLsXirpKASBCuwNbJYEOLl/cU7kJ6ahH4dsp0WRRASDhkRCIKQcKwoOoSZK3Y5LYZnEEUgCELCMeGj5bj5zcVOi+EZRBEInsGDsY+CQzDkYdGDKALBM8irLWhFOg36EEUgCELCIYpAH6IIBEFIOEQP6EMUgeAZ3l64zWkRBI/gxWSaTiKKQBAEweeIIhAEIaGorK7Bml0lTovhKUQRCIKQUBworXBaBM8hikAQhISiRqYHdCOKQPAUFVU1OPHuz50WQ3AxNTJRrBtTFAERFRDRWiLaQETjVfafRUQ/EVEVEY0J21dNREuUz7TwcwUhlKMV1aioqnFaDMHFiCLQj+Hso0SUDOB5AOcB2AFgERFNY+ZVIYdtA3ANgL+qXOIoM/cxKofgE8hpAQS3UyP9BN2YMSIYAGADM29i5goA7wAYHXoAM29h5mUA5CcSdDFj+U7kjZ/utBiCh5ARgX7MUARtAGwPWd+hbNNKBhEVEtEPRHSRCfIkNHnjp2P/kXKnxbCNtWFugEkyIhBiMPKZb5wWwXO4YbK4AzPnA7gCwFNE1FntICIapyiMwr1799orocvYU+IfRRDetyMSTSBEp7SiunZ5n486TUYwQxEUAWgXst5W2aYJZi5S/m4C8CWAvhGOm8LM+cycn5ubG7+0HmD+mj1Ysv1gve3HKgMPeLWf/ONkmC8YYM9hUQRaMEMRLALQlYg6ElEagMsBaPL+IaJsIkpXlpsDOAPAquhnJT7Xvr4If3jrp3rb/zFzLQB/KYLw/3Tu6t2OyCF4E5kv0IZhRcDMVQBuBTALwGoA7zHzSiKaRESjAICI+hPRDgCXAPg3Ea1UTu8BoJCIlgKYD2BymLeRr8kbP73OROn+0kDv5kCZfyMn//TOEgCSVEzQhjwm2jCleD0zzwAwI2zbvSHLixAwGYWf9x2Ak8yQIdFQM4UHH+oHPl2FId1a2CuQQ0R6kcurapCRmmyvMB6kpoZRw4yUZDdMB9qPjAi04c+nw8UcVHr7as9vsZJDZfO+UjtFcpRIJQeHPfGVzZJ4k7/PWoPu98x0WgzH+GHTfqdF8ASiCFzG8qJDqts37DmCbzfss1ka54nUodtx4Ki9gniUVb8cRpWP5pQ+WVLXT0UUgTZEEXiEV77d7LQIjuCfJsxanGgQDx+rxLIdB2295+qddeNOkpOkidOCfEseoVri5gUDXD7lB9vv+cQX6zDquQW23zcUn06N6Ea+Jo/gp+F9KDLXZ4yFm4odu3eVCzovBb1bOS2CJxBF4BFq/KoIxDhkiIpq5xrjJAeiwMOfl7Rk8SzTgigCl1J0sO5kaLWG9rDkWCXKKqosksghRA94FicUgRAfoghcRiRTiJYRwbAnvnLEFmwloge8i+gB72BKQJlgPdOX74x5zO7D5ThyLMFGBIJnkRGBd5ARgcsw+u5UJ9jsqqSS8C4rIsTEWMl7i7bHPkiohygCB2FmVJk8mZdoCelED3iTWSt3YeFm+z2WDpRV1lkXZwNtiCJwkBe+3IguE80txJ5oikDwJou3HqhdPmPyPAclEbQgisBBPltW3+4f+gLFQ6LpgQT7d3xDqIUz3ANOcB8yWewgq3cerrftqTnrHZBEEMzjdy//gAUb3JHjp0hyUmlCRgSCq5E5Am9xoLTCNUoAAB79fA0A4L8/bkNpuXjURUIUgUuoqq7Bawu0J5Yrr6qus/7yN5vMFsn1/PX9pfhmvb/rVwPA7sPHsOvQMafFAAAcrayOfZADTPhoOeau2eO0GK5FFIFL2LK/FA98qr0427pdR+qsPzR9tdkiuZ4PFu/AB4t3OC2G4/z6+QU4+7H5TosRFSkx6m5EEbgE0hlA8Kvnvq1dTuSc6+L+F5tjVTUor6rvhjxzxS7bZYn0GF8/tdBeQQRdmKIIiKiAiNYS0QYiGq+y/ywi+omIqohoTNi+sUS0XvmMNUMerzH6uW9hJI7s2tcWmSaL23htwZaYx0xfthMb9hyJeZzfePyLtbbfkww9yYJTGFYERJQM4HkAwwH0BPBbIuoZdtg2ANcAeDvs3BwA9wEYCGAAgPuIKNuoTF5j6Y5DukcEwPGo20oHM0y6gT+8/RMeneE/01iQSHmoJMODoBUzRgQDAGxg5k3MXAHgHQCjQw9g5i3MvAxAeIt1AYDZzFzMzAcAzAZQYIJMniPaOxvpRT+s5BXya60C4Pj35tdGr7i0AgePVsY+UBCiYIYiaAMgNMHHDmWb1ecmFNESdHW6a4b6DpZcPMf/e/9ogsv+/T0qlDmBaC6RTjwaTipkv9bsMAPPTBYT0TgiKiSiwr17ve8yOPbVH+usx/MCMRiLthiLRE4U9h4pd1oES7lhaiEOllUAABZuLsbhY4FRQFKSuxSgk9IcijAyyhs/3WZJvIcZiqAIQLuQ9bbKNlPPZeYpzJzPzPm5ublxCeomvlpnXJkxu6McoJN8suQXAMDS7QedFcRi5qzejeUh2TyDvX09De/KXw4ldKMo44H4MUMRLALQlYg6ElEagMsBTNN47iwA5xNRtjJJfL6yzXfENyIQL40gLusYW8LRiuPBWkG3Wj05/x/6rO6E+okTP8cvCZQH6JVvowdV+t2MGg3DioCZqwDcikADvhrAe8y8kogmEdEoACCi/kS0A8AlAP5NRCuVc4sBPIiAMlkEYJKyzXfE6zXk10nScOL5/rwGA9hbUn58BdE7EOH7UlPqvu4V1TXYsr/UPAEd5sWv/BddbxamzBEw8wxmPpGZOzPzw8q2e5l5mrK8iJnbMnMmMzdj5l4h577KzF2Uz2tmyONF4mnGGJKLx2/0f3gOAODCZwMBhXqem68Vc+ScVbstm1j9YlXkCGIr3ZwPllVICnYDeGayONFRy0QaC2ZgW3H9Hp0f4wqqaxj7E3zCOJQ9JeXIGz8dO+PIMXTDG4VY8Ysy32By23n3/1ZE3Ldwk3WD/QofPvNmIorAJcQTgs9g3Pnh8nrbS3xat7jfQ3NwJIEzTKqN/lbF0YEArIk9eXZu9BTqKckWmu9kMGAIUQQeJpJZyM+eREcSUAlGGy2GNq16TCMXv/CdAYnUeXz2uqj7Uy1UBFpqdYsZNTKiCDxMpBffz7bSRExSN/zpbwAAN7+5uN6+X0JMQ+EeQJo8ymycY09Jsq65qapOvN/dTkQReJhIj77VL8XmfaU49/EvLb2HoI1nQswxySE+tC99vUlTL9lORj+/wLJra/lXE7GTYBaiCDxMJM+PvUfK6/icm83S7Qexaa873Q79Oj8C1B0JPjxjta8mz2tcpvS8higCBzArsGXwP9SLkVz8wnf44zs/m3IPNYI9q2U7Dlp2j3g5/8mvbb/n/iPl2HnI+cCscM8ZTSbCBGk/tfwbEnwZGVEEDvB+ofVVtfYctq50YVCPXfHSQsvu4SUum/IDBj06z2kx6plH/NRJ1jIiCDUN7Tl8DP+cZX+9BrciisABNu61vohKekqyZdcOvnNHyqtw+FhlbQI0vxJMBuc04SPNkjhcafcfKa/NbBqNPSXHcOGz3+i+vlXoVXqzVu7Cc/M3WCOMBxFF4AQ2jFDTUqz5af/15UY8FtKTuuj5BRj5tHsaBD9jxgCg30Nz8Pjs2D3ldbuOYEXRYcx3SUF4LebW0EN8NFjShCgCB7DDVmmFz/aew8fwxOy12BVidtq0txTbDzhvHw+lyvYoU3fYnoMN3UOfrTJ0nc0aHAGCHkrXvu6OMqnSsBtDFIEDvPjVRsvvkZps/k874JG5qPSAv/Zpj841fI1v1+/DnCh5c9xI0Ab+2bKdms+pZq430f39pv0xz7M0SjgONM0RuP/RdQxRBAlKsh/yMkdg3xFjNvtjldW48pWFuOENbWk/gpHcTuf6DwaU6/GXf+P7rXFNdLvt8dLSyIcqC1EKdRFFkKDYnZX5UFklikvdMWlqlEgpHf73cxG2F5fV2ZY3fjoOlrljsjyoAPQElu+LM9bAbQ2plhFBoVLNb8n2g1KbIAxRBIIpjHnxO5zzzy+dFsMUfh0hD89t7y7B8Ke/qW1EIjWi63eXOJLmI9i21dYs0EC8YrqtGdXSrr9bGCiPftHzC7B2t/Wee15CFEGCYnehlvV7jkSsGZtIHCmvqv0/Rz+nnjLhvCe/xmfLfjHlfnrNTXoVULw943g71Fa5TuuVp9rHiRnVEEVgM1p8tM3AZSZcz3Kssn6qji9W7cayHQexN4pZZXtxme3eSzXMuhv2eEcu8SqQZIs6KHrzCM1zidurW0hxWgC/cfZj6mkh3M6bP2x1WgRHmLbkF1zav12dbXd8sCzmef/8Yh3++UUgLfO8289Gp9xGlsgXCrN+k42aHtDiaBCvaSj82oFyq8aVQ2m5vtxaRh0KEg1TRgREVEBEa4loAxGNV9mfTkTvKvsXElGesj2PiI4S0RLl86IZ8riZeCpKxYOeouZaWLBhX8xj3ORSmDd+OraaUI+3vEpfA7NbJbXHLpt+83gaZ7XEhdecnhf7PBMmW//435/RccIMw9cBgN++9IMp1/ErhhUBESUDeB7AcAA9AfyWiHqGHXY9gAPM3AXAkwD+HrJvIzP3UT43G5VHCGD2CFzL9aqqGTsOlMU+0CYueMp4AroXv9qky3wy8JH6MQzlNpmInpu3Xr+tXOWEBqmx05METexNG6YC0G4qenjG6trlaUvNmUcRjGPGiGAAgA3MvImZKwC8A2B02DGjAUxVlj8AMJTsns0UbGFFUXylE63gWKXxBrjo4FHsKTHWo4+3ULxeO/yc1XtQVqEvv1BQyYUXtYl5niLbwbJArimtrsMzV+zSdR/BHsxQBG0AbA9Z36FsUz2GmasAHALQTNnXkYh+JqKviGiwCfIIMG+y+Ok565E3frpmO266RTmOnOSLlbt1N7ChBL+6eWt266oTce7jX+m+l16dEwyGGxXiAaXlpw5VbuWVNVEnztX4Zv3e2uUPF1ufjTfIT9sO2HYvL+H0W7sTQHtm7gvgLwDeJqLGagcS0TgiKiSiwr1796odIoRg1oDryTlKHVqNDcy/vrQ+fYaVbNtf37R137SVmLE8/p5sMLfUda8X4sOftDd6m/fpn+PQa7sPVrNT844KsmT7wXoeUKHmshpm3VXxrnrlx9rl299fqutcI1z1sqROV8MMRVAEINStoq2yTfUYIkoB0ATAfmYuZ+b9AMDMiwFsBHCi2k2YeQoz5zNzfm5urgli28/8tfa5rH38cxEmf77GtOuVa3R7/XFLsWn3NAO9LpwPfLpSdbtZIyyrA830KoKgPITjpqjwpIgXPb8A05fXzV8UepcaZs/UyRaLtDpmKIJFALoSUUciSgNwOYBpYcdMAzBWWR4DYB4zMxHlKpPNIKJOALoC2GSCTK5ES1ZHMzFzGOxESL4ZMRcHXJD+Yc2uktrl+6apKxqz0Psz7VGikEvKq2o7DkT1A79Ckw3OWbUbEz5aXrs+6NF5ltYjNpMjcdRo8AOGFYFi878VwCwAqwG8x8wriWgSEY1SDnsFQDMi2oCACSjoYnoWgGVEtASBSeSbmdldXUoPk2JiZjAn+nvvLtpm+BoPTzeWkjmIkY7k32fWHZlZ2Xs2cu3QrKNDH/8KRREmkN9auDXuHEWRsHO0LNTHlIAyZp4BYEbYtntDlo8BuETlvA8BfGiGDF7g5+0Hbb2fmcnQnJhkqzAh5fUeHXl3AHsUXg0zki2K/TaiCJbtOATguPkkkreTFeaVb9fvwzndWph+3Vh8sHgHxvRra/t93YZEFtvIpzb7Ta/epc2V81hlNVKSCClRahjoUSo7DpRh6ndbUF0D3Pur8JASe9HTZkWbMDXa9oUGmllpZRv13LeGrxEcSB4N+T5CzXRWqLDDDuWpWhMh06zfcNpryDfocRs0C2Zttv3THp2LP79X13PDSGTuut0leOXbzXh1wea4zg9i57TeoaOV6H7PTBRZVG0tNNDsCgujYM2YE3ng04A57faQZ+KRkEAwo7y1sH66khnLtRfTMROZOw4gisAmzAjJj4d/zFqLT5YURU1NfLCsEou3FGNEWO3hrSqulFogIlMK45jxkgZNJU/MXocNeyJnvgyaQdbuLlHdb2Z50cKt3vBlX150qHY5NysdAOpVM4uHiR+vqLet1IGOEmB+KhavIorAJpxSBKt3Hsaf3lmC/g/PiXrcL4eOYVXYMDleiaurzUkkZsYrGjRzPzN3PSZ8tCyiDf0HDeUZ7cSuLLWxCH5fF/RqBSDgITTXosydRmstx4O4kwYQRWATTqU//3KtvuC7UL/7eJVXZnpKxHTDI5/5RrOZzAzV2TgjtXZ50ZYDmLtavQ7xLW/9FPU6T89db4I0x4kVLOaW2g6/ejYw53CssrreiNFsXv7WmCkxHtxWctMpRBHYhFpyLzdSUV1TayapjtNr595PVkQ066z85TCKy7TlpTHjK5uzeje+XX88c2pVnF418UT5RiPWpG60iWs7CY4Sd6mMGBMBtyhcpxFFYBNujbwM9xXfsq+sdlu8jeb6PUdUba/BiWu7g9OufOV4WgG1uYvwOsR2EKuhj/e7F/Tx1sJtEWtU+wlRBDYRy0bvFGdMnldnfcQz39TWHq4wkD5ZLYIz2LZpNZO99p35poLMtPoe004M1oKms7zx01UnsZ3yoonEzJXHcy01zrDO6/z8J/Un2gO0FdOJxHCLTV5eQBSBDdz3SX0vCSdYExJXMH/tnohFV4K9UbMnLINzDlU1NZpGBduLzXflvOGNRfW2GVF48ZKcRLXRudtVajhMX+YuRRDK4WPWpWlYF2dReSdSoCQSoghsYOr37ijzOPW7LVinuEde+9oifLY0emNz1GQ7ddA8du7jX+EFh7KUqtUoqM2waiPJSVSraIN/3164Fa8qE6bSrGmnvKpad/ptoS6iCHzEf3/cjoKnvq5tkGOl/73nf8ZHMmUVVTh0tBJ546fjlAe+qN3+3LwNMc8deXJrw/ePRkVVDUqOVWKPSnlJqzl8rKr2d2jeKA3b9pfhro9XYJLiQtmmaYbtMrmFw8f0TeA+aILbad746Yav4WVEEVhMvNWprKKGgVIDRVb00vPeWfjzu0sA1E1lfUxDLeDcRumWyJQ3fjqKSysw/sNlOOn+L9CqSQNL7hOLexST4dpdR/C/JXUztw/u6s1U62Zw8v1f6HpvzPbo8iOSa8hiXvgyds/XbuzuAc9TCUA6s0tzW2UI59QHZ9cun9G5me15oIDjMR53fby83r5KB+Yt3ERVDSNNmQA+Ul6FBRv2ITcrHae2z6537Jqd6tHgejlUVokmDVNjH5iAiCKwmNW7zHlIzWTYE8aLuhulb/tsrN55GN1bZalGd57ywBe2+Xh/t9FdUcX//XEbSn2eNz9oNnt9wWbcr+Q+atIgFUvvO7/esfs11kuOxZ0fLsOLV/Uz5VpeQ0xDVuMuy5BrWLPzMIY//Q1mrdxdp3e+YMM+5I2fbmugzzQHRgPRmPDRcvy07aDTYjjK1uJSfL1ub60SUMNsZbm86BAu+/f3pl7TK4giEBzhi1WBVA83v7kYxaUVtUFdoYnO/IzfRwQFT32Dq1/9sc628FCBXvfNwnuLtqNBarIp9yw6eBQLNxdj0KOBTLHrd5dETdaYSIhpyGJ2OeCR4kUG/2M+crPSffPixcIrGUrtpKyiGoVbivHsvA14/NJTAAB3fLjM9PvsPHQMX6zchXH/WQwAmP3ns9C1ZRYAYPqyX3Bqh2y0dsjBwCrIi4EY+fn5XFhY6LQYmjj7H/Ox1YEUBoKQyJzYslHcwWfxsO6h4UhOInS+K1CIccvkkZbfc29JeW3678rqGqQmJ+FIeRUapcfffyeixcycH77dFNMQERUQ0Voi2kBE41X2pxPRu8r+hUSUF7JvgrJ9LRFdYIY8bkKUgCCYz+7D9o4c+z74Be6ftrLOttLyKqwwYMqsqWG8vXAbqmsYh8oq8dScdTikFBZiZvR/eA7W7y5BRVUNuk78HPPW7Ebv+2ZhT4n5VgbDioCIkgE8D2A4gJ4AfktE4fUJrwdwgJm7AHgSwN+Vc3sCuBxALwAFAF5QrqeL8qpqw+kQmFk1TL2sogoFT31de/2t+0s1ZYZk5noyScpbIRx5JOLD7qyhpeXV+M8PxzMElFdVo9d9s3ChkqZ7za7DWKc02sFU7rHaiU37juCuj5dj9c7DmPr9Fjw1Zz3+/fVG/LTtAK55LZAKZe+R8tpMtde9HrCC3PTGYny4eAfOffzLOterqKrBzBU7sWmv+kgp2ryTGXMEAwBsYOZNAEBE7wAYDSB0un80gPuV5Q8APEcBn8HRAN5h5nIAm4log3K9mFP3P24uxvbiMgzslIORz3wb0M4PXIAMZeLoaEU1UpPr1+FlDhRNYWbUMPDLwaNol9MQHSfMQIusdPw4cRiYGW8u3IaUJEJmegrW7CrBDW8U4ut1Ab/viSN6YGCnHMxcsQt/HNoV24vL0LVlFm57Zwn+t6QIn/3fmXh67nrMXlU3973LYssEFyCPhDfpdvfM2uW1u0pQ8NTxxHXDerTE0B4tMOGj5dgyeSTW7S5Bg9RkNG2Yite/24K8ZpnISE3Ggg2B9OgXPvstRvc5AQDwwpcb66Rf+deXG7EmzAX95+0H8fP2gwACOalyMtNQdKAMM1bswrw1e9C0QSp+3bcNGqQlY39pBfI7ZKO0vAr3f7oKyY1btFX7fwzPERDRGAAFzHyDsn4VgIHMfGvIMSuUY3Yo6xsBDERAOfzAzG8q218B8DkzfxDtnrkde3LmZY+p7rtyYHvkZKbhmXkbMKBjDpo0SEWbpg3QPqchpny9ybLJ2yHdcnUXgREEQbCTopduPly5f3uT8O2e8RoionEAxgFAcuNcZEY47s2F22qXf9xcbINkAUQJCILgerhG1V5lxmRxEYB2IettlW2qxxBRCoAmAPZrPBcAwMxTmDmfmfMzsuqHmQfp1jIL/Toc39+0QSo65WaiQKm5ahVm+TILgiBYiOq0lBkjgkUAuhJRRwQa8csBXBF2zDQAYxGw/Y8BMI+ZmYimAXibiJ4AcAKArgB+RAy6t8rCXy8+CYVbinHVoDxc8u/vUVFVg2X3n19bo3Z7cRmaNkxFVkbd3CF7S8rBzMhMT8Huw8ewvOgQRvdpU5t9cM2DBQCAcf9ZjM65mejaIgt3fbwcHZo1xNb9AQ+g24Z1Rc/WjfGvLzfi4V+fhJ2HjuLc7i1w4t2fo7Ka8faNA3HrWz9rLskoCIJ3GT+8OyZ/vqZ2/aQ2TTCgYw5e+XYzFt41FOt3l6BhegraZjfArW//jNM65aCiirF212HMVywJw3q0wJzV9XNytctugO0HItflGN67FU5u2xTbikvx07aDWKvMJ/RonYWiA0eRkZqMgt6tsOPAUSXnl3oRWVPiCIhoBICnACQDeJWZHyaiSQAKmXkaEWUA+A+AvgCKAVweMrk8EcB1AKoA3MbMn8e6X3gcweZ9pchMT0aLrPhT927bX4ZmjdKQGeaju/9IOfo9NAcrHwh4ti7cvB8nt22K5hEyY9bUMJKSCKXlVThWWY1+D7mzMpkgCPGRlpKE78efW/tub5k8El+v24OU5CSc2DILqclJaJyRgq37y5DXPJIRG1hRdAgXPvstXrumP37adgDPztuA35zaBoO75uI2JWPvf64fgD/+92ccKDvuJdU4IwW3DOmMv89cWyee4UBpBf799Uac17NVHasIEHCSWV50CKe0y1aNI5CAMovxe55zQbCCJg1SbXchHd67FT5fESjZuWXySGzaewRLdxzEr/uqOuLEpKq6BvdNW4l7LuyJ3YeP4bZ3l+CZy/uiXU5DMDM6TpiBz/7vTLRukoF+D83B45ecgtvfX4o5fzkbHZtn4lhldb2OaywiBZSJIrCYbnd/XicPvyAIxmnZON3WoLJFE4chMz0ZPe+dBcCeyOKftx1An3ZNQUQoOngUJzTJwNrdJejeqnHc17Q0sliITPdWWU6LIAgJw4Th3QEAb1w30NL7jDurU+3y2zcORG5WOhqmpeCOgm6YdusZlt47SN/22bUp2ts0bQAiMqQEouEZ91Gv0ja7IZbukIyasXjrhoHo1yEbz83bgOfmu6+Yj92c3LYJlslzU4eczDTcdHZn3HR259ptE0f0wJNz1qGswtz62neN6IHTOuWgXXbD2oRzAPD7IV1MvY9bkBGB4AjBkdJ4pYd3RpfmyEhNRo/W1vR4vEbDNH/30aZeN6D22QhSHRaaP/f2s3HtGXmmKYHGGYHvPGj2Obd7yzpKIJERRWA1kkxGlQt6tcKLV56KGwd3wuZHR9RuH3FSK6yaZG/uwbxmDW29Xywu6dcW+R0ix8r4gf552bj57M64elCH2m3h85mdcxvVSyFjhJPaNrHF9u9GRBFYTEsDLq1W8fq1/Z0WAZv3laKgd2skJ1GdUpVEhIZpKdgyeSQuy28X5QrmcXOIqcENPHbJKWjq09q5QVKSAk3TpNG9sWD8ubhxcEc8MLqX6rHpKeY0Y5MvPtmU63gRUQQWc0dBN6dFqEdflQLgdvPpstjlIRumWxetPfvPZ2FAxxwAzlVFa9Ig0NgX9GqFMf3quiCm+DxVbXLI/9+maQNMHNkzopvmSW3qpc6Ji3Y57hoZ2okoAovJcGHqicw0+2T6YcJQFN49LK5zj5o8ARhkzYMF6NoyCy9dnY+v/3YO1u0uiX2SBXz2f2cCAG48q2MdEwgAX08Uz7xtcB1FEItWTdw36vYaogh8RM/WjTH1ugG1dtW/XRB9tHLrOcY9JFo1yUDzRunY+MiI2vQdAHDlwA5RzgrwzqLthu+vRlA5N2mQivbNGjqirBtnpCBJaez2H6nAyW2b4sbBnXDlaYHvZfUuZ5STG9DrIvnQRb0N39OvcwNBRBHYQE+XeMI8evFJOPvEXACBRv6KAe2jHh8tPF4vyUmEJGUu4MNbBuG+X4XXLnKOOwu6xz7IZGoYSFa+j1TFxj1xZI/aRs3fhiF9NG2YJkWfDCKKwAZm/Gmw0yIACPimB/nrBd2QnZkW9fg0kybhggSH+7mNMjR5e1hhJ3/p6npBlbWJCu2kqqam1qTRPLN+3qrg/IUbCc5tuAlSz6UmaEQUgU189bchToug+rJ88oe6UZKvXdsfH94yCACQlhz/y6VWYDvYridpfOomjOgR9/0jofYf1TiQZiXoE7/mwQKc1Lb+ZGf4nIHThKZxtzLHT7wmmvAYAz24wYvOaUQR2ER4Omy3cEq7pnXW++floE+7gFdRitYWWwVWKcIYVERae29m9fEmX3xS7XJFdf28T+0d8BYJfreR5ieMfPeCdkae3BpDurVwWgzHkafNJrxiw8xISaqVVY/nRigf3DwINVHy7GVlaIuaNeM769chG5f1jx6PkKThRlk6szzG4sWr+kXdn5Hqrlczp1F0M6JX6eBjl9FQ3PW0JTBaGhsrGKjT1pySnBTSc4/vngfKKiOaW7ZMHmmrTT40YK1Jg9TayfJwbhvWNep17jV5cjuSHEEau8QOv/GRQNR344xUyz1rnEjQaMCilFCIIrCJZIcms/KaZWJMv7b49NYzox6n1vtPilPm1GQyxe5uxjsa/LfO6NIML12dHzF/+zWn50W9jtkT57FwS/xJ8Ln4cfN+AMCHt5yOId2iK7F4+dwBpwovpuG3An9ntvIBD4zupalR6Z+XjXfGDaqzLTdLvQpbLJiNTd6FXscoQWX21g2nabpXVkYKSo5VWSKL12ib3aB2+adtBwEETG1GOzV/PLcLnplXP8OsE54/PvxZVZERgU3orSRkFlqUwMK7hmJKmFvluoeGx50JtE12Awzp1gJnxTB/2IHWtiU7Mw1f/nUI+uepm9LUJr/18PHvT69dXvtQQZQjnSeYFuXpy/vWbrvnQvNMY38+78R6235zanxVvoxSI7YhAAYVARHlENFsIlqv/FVNYkNEY5Vj1hPR2JDtXxLRWiJaonwSevq+a4tGtt6vo8aAsJaNM+rZ7Y2YQk5smYVXr+mPN64bEPc1APt7a3nNMyN6KhkdEYTmd4rX5KYFM8wrGSmBzkOLkBFhaOyA0d9FrefvVGxCy8aSngIwPiIYD2AuM3cFMFdZrwMR5QC4D8BAAAMA3BemMH7HzH2Uzx6D8riaKwZGj+Q1m9Ym5mAZ2t1+HW2G/TbVxDTFZmGlIjDSoAbTccea37HCrt7zBGei7687s6Mj93UbRt+S0QCmKstTAVykcswFAGYzczEzHwAwG4C7x8YWYbeducrEYa8Tc92X9DOehvrB0cbz0ADGfrvwdNrxuuVqwYiSOTWkBsIrY/PRpmkD1ePOtaBTEJ591S6s/C28hFFF0JKZdyrLuwC0VDmmDYDQ7GE7lG1BXlPMQvdQgseJN7PZF7tdtrd9pJuYkJO/ReP4JrzNpKD38ajc+y3OsaS3XQt1CHjk14HAO2ZgaI+WEV2erxqUhykhcRBf/nUI/vcHe+r4GsXseJBEIaYiIKI5RLRC5TM69DgOjBf19pt+x8wnARisfK6KIsc4IiokosK9e/fqvI07GHXKCbbda1iPlnjkYnN6w4D2OYNmMfIX2U16ij43zBtDipaHYtbYysyKWmro7UsF8zllpafUOhaET4w/NubkOsoMqDvyaJCW7Jn6CTI1rE7Mp5KZhzFzb5XPJwB2E1FrAFD+qtn4iwCEjo3bKtvAzMG/JQDeRmAOIZIcU5g5n5nzc3Od90aJBzsHPFkZKbobQTWG9QiYAbTK/s9LTzF8Tyc5rVOzetsu6nMCBndtHvc1gw3ruLM64byeaoNm89Br6khR8klVR7F9XZLfrl7uqND7EOm/75+GHg/gs9MsdP8o9Spnfsdo92QagKAX0FgAn6gcMwvA+USUrUwSnw9gFhGlEFFzACCiVAAXAlhhUB5BwaxEai+P7R+IKNV4ucqqKLklPMqdw7sb8i4J/hR3jeih6zrrHx6u+156O+bBmIC3b4weZxFOaL8gLTlJd2nNUBfSf15iX+fBqbkIt2PUYDYZwHtEdD2ArQAuBQAiygdwMzPfwMzFRPQggEXKOZOUbZkIKIRUAMkA5gB4yaA8gkVoVSztXVYI3gyyGxozd8Ufoa2/n9ZAZ/W5oKmqT0jyQS0/dWhSvKYN09BE4/Oh1aVZsBdDioCZ9wMYqrK9EMANIeuvAng17JhSANEzbwlx40QkbGoy6a4uZSXBUpBGmKQxMjvIWzcMxO9eXlhnm13pKf73hzNAOnO2qkUJa/E2C0+OqtV0GBpbkpOZhuLSCk3nCdbiPidrwRTMzrGvJRiusto9U3ErHrgAvU0oaq63V35Gl/pzCXrNJvFC0O/mq3b8f3/cFvM8M2IhFt89DD/eVa8fGRdPX97HlOv4FVEENhNvIXenuW1Y/bQAbkatME48qE3uXnN6Hp67oi/SoiiJX/dtg6X3no8fJw5FrxOMKyQtEOmv4aDWoB/WUHgmXjUQ2kEhIrQwKbI3r5mYnIwgisBmmjeyx6/d7L65U2m0nUbt97ptWFdcePIJUXvfAzvmoEnDVLTIsi+FAYF0e6bFG1AVrwfcsUprnAn0jlAuzZdJ41BEESQodqfXzbbJ/OEGgtXmpv9RPa/P69f2x0V926ju08vP95yn+dh43Djj1e/xWoZObGlNvi3dJjHT6t8lBhJmJ5jCtFvPRKVKGUgv8tq1/XHta4vqbR87qANG9TmhtrHtEmHexMzSh9k6AvSCjWFuVjr2lpRrOsfukZ5VsTRaLts/L5BC47ExJ6umGvczMiJIUOz2GmqX0xCdcu3NrmoVfdo2Vd3+wOje6Nehbprqn3T02K0m2MvV09RmqAQdanl03GYp1NLDD+auuiS/neQYCkNGBAmKGYVh/Ep2Zhomje6FI+Wxe405mWlo3igd+46UW17KMRbBXrGeTvcNgzvizuHddd/LbYV6wt1Z1ZDGPzIyInCAm85Wz2djJmZmHg0Sq9ylW1hsgmfW1YPy8PshXUyQxj6CE6anRBjRqNEgNblOMJnW893kKgzot/kndnpL/YgiSFCssNef1LZJPW+LrIwUZKS66zFqZpNn1nHc0SgGG7cXfneqoeu0y4kdHR4ccd41Qv9owgq0dPZDG3/RA3UR05AT2NBulFuU8+cfY07BwI7NcPv7SwE4U3BcUCfYuBnJcDrztsER6xCEEqxnPO6sznHfy0wSPIO95YgicIAcG1I1V1iY/C3UFt3W4zUPEonwxjArPQUlGuY5QtGaIiSveabjcyKhaNEDoccM6twMp7Zvapk8XsNdY3qfcL0N5fGszA8ffKGeuqyPZffwEu/eNMiUvEZGCW8MNXWSE6QjrTegrEuLLHz0e28U07EDUQQOYFZxkk8iVIV6ZWw+/nWl9fn8RvcxJ2jKTF6/tr/t9+yc28iUvEZGSQ1znfGTl4w2neef70Mvogg8TPMs9UnRri2y6pQgNJsuuVmWXdso3Vq5VzarCRaZAYABHXO0pc92xzy3YcxIgudnRBF4mEiPfmiDYAUntW3iKvuwn/nNqce9uEJjR967aVBc9Qys5M3rB1p2bdEDxnDXkyLoItLQ3yv1YwVtvHjlqXX+hpKvpE0AUG8UGF572Gn0Fs3Rg5/MYFYgisDDROoF+fmlaKCjiIxXKOjdWlmK/rvqKaBjRe98aPfoOZaqLMxFpaXzI6OGyIgicAmTRusvqk0g3H5e/ToBmSbl4vcas247C00NlpX0Gh00BH+p0aSB+dliX7km+kS9FdHutUgjbwhDioCIcohoNhGtV/5mRzhuJhEdJKLPwrZ3JKKFRLSBiN4lIn+9xSEM61G/AEosiIB+efW/cj09w0QhOYl8N1G8ZfLIiBlQo3HN6XnHzzO5Af3tgHYR91npWSWTxcYwOiIYD2AuM3cFMFdZV+MxAFepbP87gCeZuQuAAwCuNyiPZ4mnr0QQlzi/MUOpgbBQKfGo57kZ0i0XAHD/qF6W2ev/rDJCDWLFKCSIXQWfEhWjimA0gKnK8lQAF6kdxMxzAZSEbqNAGOS5AD6Idb4fsLuQTKLhB3VIBPQ8oXHtMqAvC6hatHlLk0pF1uLgYzzqlBOcu7nHMaoIWjLzTmV5FwA99o1mAA4yczAGfgcA90Uo2UQ8ekDyqxzHUvuzSwg1+QVHgjU6HpxrTs+rs75l8kh0TpAaEkBgpCPER0xFQERziGiFymd06HEc6NJa9jYS0TgiKiSiwr1791p1G9toZkK+IYKMJEb3CfQCO+UmdvHytJQkdA2ZD6gdEei4xvm9WmHVpAvMFcxFxOoWSccpMjEVATMPY+beKp9PAOwmotYAoPzdo+Pe+wE0JaKgi0tbAEVR5JjCzPnMnJ+bm6vjNu5kcVhlq3ja8yQi1cliP9KpeWIrgnUPDccJIVlBGymeYXo7Ag3TrPUoc7Jb0ihD/X9b82CBzZJ4D6OmoWkAxirLYwF8ovVEZQQxH8CYeM5PNKIN8SM+yASkq5QaFBKbLZNH1pqJ0lwWPezkADVSJLUfvej0YvQpmgzgPCJaD2CYsg4iyieil4MHEdE3AN4HMJSIdhBRcHx6J4C/ENEGBOYMXjEoj2eJ9v5EepAb+TReIJTgYN+vFrIWUSZ7/fqdCPox1JIw834AQ1W2FwK4IWRdtXoJM28CMMCIDIlCPLb+YARxw7RklFVUmy2SJ2AAN53VCWd2be60KI6Rk5mG4tIKp8UQPIy7xpU+ZcvkkYZsq3eP7GmaLG7j4lNjO5JNGNEDg7t6f97IbC7NjxzcZRVuy28UikwVR0YUgUvQOyB447rjA6krBrY3WRr30LSBb4PNNXOsUn00eONZnWyWJDK/H+KOkpaCOqIIXMIJTTMwWId5o1+Hut5C5/XUn6LC63TOzcSZXfxrEgry0e9Px/Q/Ol8hDYjcobmjwB1F7gV1RBG4hIZpKfiPjoyQ4YnlXro632yRXM/c24fgEgfMH26je6vG6HWC8xXSAOtrYcTLyJNbY0DHHKfFcC3idiK4GokB8hYtskxOWWGQ684I1Ad//or6tRyE48iIwEHUcudf0Mt/Jh4hsdj4yAjcOLij02IAAE7t0NRpETyBKAIHUZvkvfK0DoaumWg96AT7d3xBchJJOgePIaYhB7mjoBtuOttczw4pUym4gbxmx1N+bHpkhIOSCFqQEYGDpKckm25TTbQCHQn27/iG3w5oVxv5nuRg50TqdWhDFIHLMJoWINHqFYuJwZsQES7vb79Hl9kjbL8gisAjFPRqpek4K6tACYIefFAiImGQOQKXEakDrMU/+9s7z4mYgdGryHjAu+gpmmMWYgqKj8RqNRKINiG55wFtk8BtsxuaX3rQaeS99ixOFE1yc64jNyOKwCMkJ/nzp5Ienndxg2nIiVGJF/Fn6+JBEszioxmZKzaGnvxVZlPtAtPQnNW7bZfBi8gcgUdING8gwV6euPQU2+95y9md0bddU9vvG0ppuT/rdOjFp/1M7zHqlNh5+RMRUX/mcFEf+5+fdjkNbU8K2Kpxep11MQ1pQxSBy+ic20h1+6DOzdDH4d6VE0T0opIRkiaCzgNOBnXZydWD8uqsd22h/j4JdTGkCIgoh4hmE9F65W92hONmEtFBIvosbPvrRLSZiJYonz5G5EkETgjzFgqlfU5DAECrRPMMikKkyeIF48+1WRJv8tBFvbFo4jCnxbCNcIU34qTWDkniLYyOCMYDmMvMXQHMVdbVeAzAVRH2/Y2Z+yifJQblSWiCveMJI/xT5CPSiKBFVrr6DqEOGanJyPXxd5VoKVeswqgiGA1gqrI8FcBFagcx81wAJQbv5TuW338+1j88vHY9RXEh7RCS0Mtv3HpOFwCSekLQhjwm2jCqCFoy805leReAeJLpP0xEy4joSSLyb9clhDsKuuGOgm7IykitEyl8z4U9AADJPnq6w//Tm6X2raADGRFoI6b7KBHNAaCW6GZi6AozMxHpnaKfgIACSQMwBcCdACZFkGMcgHEA0L594hZrB4DfD+miur1pw0Ahd1+5koa9yE5EqwrepWlDyb2lhZiKgJkjzjQR0W4ias3MO4moNYA9em4eMpooJ6LXAPw1yrFTEFAWyM/P93VrkJ3pn4fbRypPsIBozhfCcYyahqYBGKssjwXwiZ6TFeUBChh8LwKwwqA8Cc+mR0agdRP/PNwtwvzCfd0DEDQx67aznBbBcxhVBJMBnEdE6wEMU9ZBRPlE9HLwICL6BsD7AIYS0Q4iukDZ9RYRLQewHEBzAA8ZlCfh8Ys/eJDf9m+PxXcfH5SKZUiIhVotcCE6hlJMMPN+AENVthcCuCFkfXCE88UZXIhKUhKhWaOQUYEoAiEGMj+sH4ksFjxFeqo8skJ0/DZqNgN5qwRPkZGajC2TRzothuBiRA/oRxSBIAgJhZ/ibMxCFIEgCAlFY6nbrRtRBIIgJBQZqcno3irLaTE8hSgCQRAEnyOKQPAMF/U5wWkRBI8gSQn1IYpA8AwN06WyqqANUQP6EEUgCELCIQMCfYgiEDyDvNuCVkQR6EMUgeAZ5OUWtBKpxKmgjhhdBUFIOP5wThes/OWQ02J4BlEEgiAkHAW9W6Ggt1o9LUENMQ0JnqBzbibO7JLrtBiCkJDIiEDwBHNvH+K0CIKQsMiIQBAEweeIIhAEQfA5oggEQRB8jiFFQEQ5RDSbiNYrf7NVjulDRN8T0UoiWkZEl4Xs60hEC4loAxG9S0RpRuQRBEEQ9GN0RDAewFxm7gpgrrIeThmAq5m5F4ACAE8RUVNl398BPMnMXQAcAHC9QXkEQRAEnRhVBKMBTFWWpwK4KPwAZl7HzOuV5V8A7AGQS4H0gOcC+CDa+YIgCIK1GFUELZl5p7K8C0DLaAcT0QAAaQA2AmgG4CAzVym7dwBoY1AeQRAEQScx4wiIaA4AtRC9iaErzMxExFGu0xrAfwCMZeYavfnCiWgcgHEA0L59e13nCoIgCJGJqQiYeVikfUS0m4haM/NOpaHfE+G4xgCmA5jIzD8om/cDaEpEKcqooC2AoihyTAEwRbleCRGtjSW7gzQHsM9pIWLgdhndLh8gMpqB2+UD3C+jHvk6qG00Glk8DcBYAJOVv5+EH6B4An0M4A1mDs4HBEcQ8wGMAfBOpPMjsJaZ8w3KbhlEVOhm+QD3y+h2+QCR0QzcLh/gfhnNkM/oHMFkAOcR0XoAw5R1EFE+Eb2sHHMpgLMAXENES5RPH2XfnQD+QkQbEJgzeMWgPIIgCIJODI0ImHk/gKEq2wsB3KAsvwngzQjnbwIwwIgMgiAIgjG8Glk8xWkBYuB2+QD3y+h2+QCR0QzcLh/gfhkNy0fMER19BEEQBB/g1RGBIAiCYBKeUgREVEBEa5XcRGrpLFwlDxFdQ0R7QybJb3BCzjCZXiWiPUS0wmlZgNjyENEQIjoU8h3ea7eMKjK1I6L5RLRKyaH1JzfL4tLvMIOIfiSipYrcD7hZFje+y0GIKJmIfiaiz+K+CDN74gMgGYGI5E4IRCcvBdDTzfIAuAbAc05/d2EynQXgVAArnJZFizwAhgD4zGk5w2RqDeBUZTkLwDqnnkUtsrj0OyQAjZTlVAALAZzmVlnc+C6HyPYXAG8b+Y29NCIYAGADM29i5goEYg9Gizz6YOavARQ7LUcQt8mjBWbeycw/KcslAFbDofQobpJFDxzgiLKaqnwcmbB0kyx6IaK2AEYCeDnWsdHwkiJoA2B7yLrTuYm0yvMbJf32B0TUzh7REo5ByrD9cyLq5bQwoRBRHoC+CPQiHSWGLK77DhWTxhIEMhLMZmbHvkONsrjxXX4KwB0AaoxcxEuKwIt8CiCPmU8GMBvHM7UK2vkJQAdmPgXAswD+56w4xyGiRgA+BHAbMx92sSyu/A6ZuZqZ+yCQXmYAEfV2sSyue5eJ6EIAe5h5sdFreUkRFAEI1cJRcxPZQEx5mHk/M5crqy8D6GeTbAkDMx8ODtuZeQaAVCJq7rBYIKJUBBret5j5IzfL4tbvMAgzHwQwH4F6JY4SSRaXvstnABhFRFsQME2fS0Sqwbux8JIiWASgKwWqmqUBuByBXEeulUdJxBdkFAL2W0EHRNSKlFS1FEhjnoRAwkInZSIE0qGsZuYn3C6LS7/DXFIKVBFRAwDnAVjjVlnc+C4z8wRmbsvMeQi0P/OY+cp4rmU06ZxtMHMVEd0KYBYCHjuvMvNKt8lDRJMAFDLzNAB/JKJRAKoQmBC9xil5gxDRfxHwImlORDsA3MfMjuV4UpMHgck6MPOLCCQlvIWIqgAcBXA5K64SDnIGgKsALFfsygBwl9LbdoUsANoDrv4OWwOYSkTJCCim95g5fvdHC2Rx+7tsJhJZLAiC4HO8ZBoSBEEQLEAUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPkcUgSBEgYiahWSc3EVERcryESJ6wWn5BMEMxH1UEDRCRPcDOMLM/3RaFkEwExkRCEIcKDn+P1OW7yeiqUT0DRFtJaKLiegfRLSciGYqKSBARP2I6CsiWkxEs8KiVQXBMUQRCII5dAZwLgLpB94EMJ+ZT0IgknekogyeBTCGmfsBeBXAw04JKwiheCbFhCC4nM+ZuZKIliOQcmSmsn05gDwA3QD0BjBbSfuTDGCnA3IKQj1EEQiCOZQDADPXEFFlSC6fGgTeMwKwkpkHOSWgIERCTEOCYA9rAeQS0SAgkDraLQViBEEUgSDYgFLOdAyAvxPRUgBLAJzuqFCCoCDuo4IgCD5HRgSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPgcUQSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPic/wcvziJ0eY2VRAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "librosa.display.waveplot(samples_out, sr=sr)\n", - "print(len(samples_out))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "infectious-welcome", - "metadata": {}, - "outputs": [], - "source": [ - "import librosa\n", - "x, sr = librosa.load(wav, sr=16000)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "musical-anatomy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float32\n", - "float64\n" - ] - } - ], - "source": [ - "print(x.dtype)\n", - "print(samples.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "lucky-paraguay", - "metadata": {}, - "outputs": [], - "source": [ - "sf.read?" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "annual-christmas", - "metadata": {}, - "outputs": [], - "source": [ - "librosa.load?" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "infectious-seeker", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(x, samples)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "pregnant-conditioning", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "logical-happiness", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.uniform?" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "rocky-plastic", - "metadata": {}, - "outputs": [], - "source": [ - "random.uniform?" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "focused-compensation", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.RandomState?" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "centered-repository", - "metadata": {}, - "outputs": [], - "source": [ - "random.sample?" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "inner-invite", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array(['3', '5'], dtype=' 1.0, speed up the audio;\n", - " speed_rate = 1.0, unchanged;\n", - " speed_rate < 1.0, slow down the audio;\n", - " speed_rate <= 0.0, not allowed, raise ValueError.\n", - " :type speed_rate: float\n", - " :raises ValueError: If speed_rate <= 0.0.\n", - " \"\"\"\n", - " if speed_rate <= 0:\n", - " raise ValueError(\"speed_rate should be greater than zero.\")\n", - " old_length = samples.shape[0]\n", - " new_length = int(old_length / speed_rate)\n", - " old_indices = np.arange(old_length)\n", - " new_indices = np.linspace(start=0, stop=old_length, num=new_length)\n", - " samples = np.interp(new_indices, old_indices, samples)\n", - " return samples" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "tracked-purse", - "metadata": {}, - "outputs": [], - "source": [ - "samples, sr = sf.read(wav)\n", - "samples_out = change_speed(samples, 1.0)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "steady-mileage", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ipd.Audio(samples, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "regulated-google", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ipd.Audio(samples_out, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "homeless-forge", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples_out = change_speed(samples, 1.1)\n", - "ipd.Audio(samples_out, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "exciting-blocking", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples_out = change_speed(samples, 0.9)\n", - "ipd.Audio(samples_out, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "through-botswana", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "66078\n" - ] - } - ], - "source": [ - "print(len(samples_out))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "cellular-violence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting matplotlib\n", - " Downloading matplotlib-3.4.1-cp37-cp37m-manylinux1_x86_64.whl (10.3 MB)\n", - "\u001b[K |████████████████████████████████| 10.3 MB 691 kB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (8.1.0)\n", - "Requirement already satisfied: numpy>=1.16 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (1.20.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (2.8.1)\n", - "Collecting kiwisolver>=1.0.1\n", - " Downloading kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[K |████████████████████████████████| 1.1 MB 45.9 MB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (2.4.7)\n", - "Collecting cycler>=0.10\n", - " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", - "Requirement already satisfied: six in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", - "Installing collected packages: kiwisolver, cycler, matplotlib\n", - "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" - ] - } - ], - "source": [ - "!pip install matplotlib\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import librosa.display" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "undefined-parade", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8K0lEQVR4nO2dd3hUZfbHvycdQoAEQpEWmlQVJICoKApqABdcF8u6KlbUXX+77rqrIFZs7Lr2si5WXHXtrigI0myoSFB67xDpoYQEUs/vj7kTJpM7M/fO7XPP53nmye33ZObe97zveU8hZoYgCILgX5KcFkAQBEFwFlEEgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPscURUBEBUS0log2ENF4lf1/IaJVRLSMiOYSUYeQfWOJaL3yGWuGPIIgCIJ2yGgcARElA1gH4DwAOwAsAvBbZl4Vcsw5ABYycxkR3QJgCDNfRkQ5AAoB5ANgAIsB9GPmA4aEEgRBEDRjxohgAIANzLyJmSsAvANgdOgBzDyfmcuU1R8AtFWWLwAwm5mLlcZ/NoACE2QSBEEQNJJiwjXaANgesr4DwMAox18P4PMo57aJdcPmzZtzXl6ePikFQRB8zuLFi/cxc274djMUgWaI6EoEzEBnx3HuOADjAKB9+/YoLCw0WTpBEITEhoi2qm03wzRUBKBdyHpbZVu4AMMATAQwipnL9ZwLAMw8hZnzmTk/N7eeQhMEQRDixAxFsAhAVyLqSERpAC4HMC30ACLqC+DfCCiBPSG7ZgE4n4iyiSgbwPnKNkEQBMEmDJuGmLmKiG5FoAFPBvAqM68kokkACpl5GoDHADQC8D4RAcA2Zh7FzMVE9CACygQAJjFzsVGZBEEQBO0Ydh91gvz8fJY5AkEQBH0Q0WJmzg/fLpHFgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgif42/tLsXirpKASBCuwNbJYEOLl/cU7kJ6ahH4dsp0WRRASDhkRCIKQcKwoOoSZK3Y5LYZnEEUgCELCMeGj5bj5zcVOi+EZRBEInsGDsY+CQzDkYdGDKALBM8irLWhFOg36EEUgCELCIYpAH6IIBEFIOEQP6EMUgeAZ3l64zWkRBI/gxWSaTiKKQBAEweeIIhAEIaGorK7Bml0lTovhKUQRCIKQUBworXBaBM8hikAQhISiRqYHdCOKQPAUFVU1OPHuz50WQ3AxNTJRrBtTFAERFRDRWiLaQETjVfafRUQ/EVEVEY0J21dNREuUz7TwcwUhlKMV1aioqnFaDMHFiCLQj+Hso0SUDOB5AOcB2AFgERFNY+ZVIYdtA3ANgL+qXOIoM/cxKofgE8hpAQS3UyP9BN2YMSIYAGADM29i5goA7wAYHXoAM29h5mUA5CcSdDFj+U7kjZ/utBiCh5ARgX7MUARtAGwPWd+hbNNKBhEVEtEPRHSRCfIkNHnjp2P/kXKnxbCNtWFugEkyIhBiMPKZb5wWwXO4YbK4AzPnA7gCwFNE1FntICIapyiMwr1799orocvYU+IfRRDetyMSTSBEp7SiunZ5n486TUYwQxEUAWgXst5W2aYJZi5S/m4C8CWAvhGOm8LM+cycn5ubG7+0HmD+mj1Ysv1gve3HKgMPeLWf/ONkmC8YYM9hUQRaMEMRLALQlYg6ElEagMsBaPL+IaJsIkpXlpsDOAPAquhnJT7Xvr4If3jrp3rb/zFzLQB/KYLw/3Tu6t2OyCF4E5kv0IZhRcDMVQBuBTALwGoA7zHzSiKaRESjAICI+hPRDgCXAPg3Ea1UTu8BoJCIlgKYD2BymLeRr8kbP73OROn+0kDv5kCZfyMn//TOEgCSVEzQhjwm2jCleD0zzwAwI2zbvSHLixAwGYWf9x2Ak8yQIdFQM4UHH+oHPl2FId1a2CuQQ0R6kcurapCRmmyvMB6kpoZRw4yUZDdMB9qPjAi04c+nw8UcVHr7as9vsZJDZfO+UjtFcpRIJQeHPfGVzZJ4k7/PWoPu98x0WgzH+GHTfqdF8ASiCFzG8qJDqts37DmCbzfss1ka54nUodtx4Ki9gniUVb8cRpWP5pQ+WVLXT0UUgTZEEXiEV77d7LQIjuCfJsxanGgQDx+rxLIdB2295+qddeNOkpOkidOCfEseoVri5gUDXD7lB9vv+cQX6zDquQW23zcUn06N6Ea+Jo/gp+F9KDLXZ4yFm4odu3eVCzovBb1bOS2CJxBF4BFq/KoIxDhkiIpq5xrjJAeiwMOfl7Rk8SzTgigCl1J0sO5kaLWG9rDkWCXKKqosksghRA94FicUgRAfoghcRiRTiJYRwbAnvnLEFmwloge8i+gB72BKQJlgPdOX74x5zO7D5ThyLMFGBIJnkRGBd5ARgcsw+u5UJ9jsqqSS8C4rIsTEWMl7i7bHPkiohygCB2FmVJk8mZdoCelED3iTWSt3YeFm+z2WDpRV1lkXZwNtiCJwkBe+3IguE80txJ5oikDwJou3HqhdPmPyPAclEbQgisBBPltW3+4f+gLFQ6LpgQT7d3xDqIUz3ANOcB8yWewgq3cerrftqTnrHZBEEMzjdy//gAUb3JHjp0hyUmlCRgSCq5E5Am9xoLTCNUoAAB79fA0A4L8/bkNpuXjURUIUgUuoqq7Bawu0J5Yrr6qus/7yN5vMFsn1/PX9pfhmvb/rVwPA7sPHsOvQMafFAAAcrayOfZADTPhoOeau2eO0GK5FFIFL2LK/FA98qr0427pdR+qsPzR9tdkiuZ4PFu/AB4t3OC2G4/z6+QU4+7H5TosRFSkx6m5EEbgE0hlA8Kvnvq1dTuSc6+L+F5tjVTUor6rvhjxzxS7bZYn0GF8/tdBeQQRdmKIIiKiAiNYS0QYiGq+y/ywi+omIqohoTNi+sUS0XvmMNUMerzH6uW9hJI7s2tcWmSaL23htwZaYx0xfthMb9hyJeZzfePyLtbbfkww9yYJTGFYERJQM4HkAwwH0BPBbIuoZdtg2ANcAeDvs3BwA9wEYCGAAgPuIKNuoTF5j6Y5DukcEwPGo20oHM0y6gT+8/RMeneE/01iQSHmoJMODoBUzRgQDAGxg5k3MXAHgHQCjQw9g5i3MvAxAeIt1AYDZzFzMzAcAzAZQYIJMniPaOxvpRT+s5BXya60C4Pj35tdGr7i0AgePVsY+UBCiYIYiaAMgNMHHDmWb1ecmFNESdHW6a4b6DpZcPMf/e/9ogsv+/T0qlDmBaC6RTjwaTipkv9bsMAPPTBYT0TgiKiSiwr17ve8yOPbVH+usx/MCMRiLthiLRE4U9h4pd1oES7lhaiEOllUAABZuLsbhY4FRQFKSuxSgk9IcijAyyhs/3WZJvIcZiqAIQLuQ9bbKNlPPZeYpzJzPzPm5ublxCeomvlpnXJkxu6McoJN8suQXAMDS7QedFcRi5qzejeUh2TyDvX09De/KXw4ldKMo44H4MUMRLALQlYg6ElEagMsBTNN47iwA5xNRtjJJfL6yzXfENyIQL40gLusYW8LRiuPBWkG3Wj05/x/6rO6E+okTP8cvCZQH6JVvowdV+t2MGg3DioCZqwDcikADvhrAe8y8kogmEdEoACCi/kS0A8AlAP5NRCuVc4sBPIiAMlkEYJKyzXfE6zXk10nScOL5/rwGA9hbUn58BdE7EOH7UlPqvu4V1TXYsr/UPAEd5sWv/BddbxamzBEw8wxmPpGZOzPzw8q2e5l5mrK8iJnbMnMmMzdj5l4h577KzF2Uz2tmyONF4mnGGJKLx2/0f3gOAODCZwMBhXqem68Vc+ScVbstm1j9YlXkCGIr3ZwPllVICnYDeGayONFRy0QaC2ZgW3H9Hp0f4wqqaxj7E3zCOJQ9JeXIGz8dO+PIMXTDG4VY8Ysy32By23n3/1ZE3Ldwk3WD/QofPvNmIorAJcQTgs9g3Pnh8nrbS3xat7jfQ3NwJIEzTKqN/lbF0YEArIk9eXZu9BTqKckWmu9kMGAIUQQeJpJZyM+eREcSUAlGGy2GNq16TCMXv/CdAYnUeXz2uqj7Uy1UBFpqdYsZNTKiCDxMpBffz7bSRExSN/zpbwAAN7+5uN6+X0JMQ+EeQJo8ymycY09Jsq65qapOvN/dTkQReJhIj77VL8XmfaU49/EvLb2HoI1nQswxySE+tC99vUlTL9lORj+/wLJra/lXE7GTYBaiCDxMJM+PvUfK6/icm83S7Qexaa873Q79Oj8C1B0JPjxjta8mz2tcpvS8higCBzArsGXwP9SLkVz8wnf44zs/m3IPNYI9q2U7Dlp2j3g5/8mvbb/n/iPl2HnI+cCscM8ZTSbCBGk/tfwbEnwZGVEEDvB+ofVVtfYctq50YVCPXfHSQsvu4SUum/IDBj06z2kx6plH/NRJ1jIiCDUN7Tl8DP+cZX+9BrciisABNu61vohKekqyZdcOvnNHyqtw+FhlbQI0vxJMBuc04SPNkjhcafcfKa/NbBqNPSXHcOGz3+i+vlXoVXqzVu7Cc/M3WCOMBxFF4AQ2jFDTUqz5af/15UY8FtKTuuj5BRj5tHsaBD9jxgCg30Nz8Pjs2D3ldbuOYEXRYcx3SUF4LebW0EN8NFjShCgCB7DDVmmFz/aew8fwxOy12BVidtq0txTbDzhvHw+lyvYoU3fYnoMN3UOfrTJ0nc0aHAGCHkrXvu6OMqnSsBtDFIEDvPjVRsvvkZps/k874JG5qPSAv/Zpj841fI1v1+/DnCh5c9xI0Ab+2bKdms+pZq430f39pv0xz7M0SjgONM0RuP/RdQxRBAlKsh/yMkdg3xFjNvtjldW48pWFuOENbWk/gpHcTuf6DwaU6/GXf+P7rXFNdLvt8dLSyIcqC1EKdRFFkKDYnZX5UFklikvdMWlqlEgpHf73cxG2F5fV2ZY3fjoOlrljsjyoAPQElu+LM9bAbQ2plhFBoVLNb8n2g1KbIAxRBIIpjHnxO5zzzy+dFsMUfh0hD89t7y7B8Ke/qW1EIjWi63eXOJLmI9i21dYs0EC8YrqtGdXSrr9bGCiPftHzC7B2t/Wee15CFEGCYnehlvV7jkSsGZtIHCmvqv0/Rz+nnjLhvCe/xmfLfjHlfnrNTXoVULw943g71Fa5TuuVp9rHiRnVEEVgM1p8tM3AZSZcz3Kssn6qji9W7cayHQexN4pZZXtxme3eSzXMuhv2eEcu8SqQZIs6KHrzCM1zidurW0hxWgC/cfZj6mkh3M6bP2x1WgRHmLbkF1zav12dbXd8sCzmef/8Yh3++UUgLfO8289Gp9xGlsgXCrN+k42aHtDiaBCvaSj82oFyq8aVQ2m5vtxaRh0KEg1TRgREVEBEa4loAxGNV9mfTkTvKvsXElGesj2PiI4S0RLl86IZ8riZeCpKxYOeouZaWLBhX8xj3ORSmDd+OraaUI+3vEpfA7NbJbXHLpt+83gaZ7XEhdecnhf7PBMmW//435/RccIMw9cBgN++9IMp1/ErhhUBESUDeB7AcAA9AfyWiHqGHXY9gAPM3AXAkwD+HrJvIzP3UT43G5VHCGD2CFzL9aqqGTsOlMU+0CYueMp4AroXv9qky3wy8JH6MQzlNpmInpu3Xr+tXOWEBqmx05METexNG6YC0G4qenjG6trlaUvNmUcRjGPGiGAAgA3MvImZKwC8A2B02DGjAUxVlj8AMJTsns0UbGFFUXylE63gWKXxBrjo4FHsKTHWo4+3ULxeO/yc1XtQVqEvv1BQyYUXtYl5niLbwbJArimtrsMzV+zSdR/BHsxQBG0AbA9Z36FsUz2GmasAHALQTNnXkYh+JqKviGiwCfIIMG+y+Ok565E3frpmO266RTmOnOSLlbt1N7ChBL+6eWt266oTce7jX+m+l16dEwyGGxXiAaXlpw5VbuWVNVEnztX4Zv3e2uUPF1ufjTfIT9sO2HYvL+H0W7sTQHtm7gvgLwDeJqLGagcS0TgiKiSiwr1796odIoRg1oDryTlKHVqNDcy/vrQ+fYaVbNtf37R137SVmLE8/p5sMLfUda8X4sOftDd6m/fpn+PQa7sPVrNT844KsmT7wXoeUKHmshpm3VXxrnrlx9rl299fqutcI1z1sqROV8MMRVAEINStoq2yTfUYIkoB0ATAfmYuZ+b9AMDMiwFsBHCi2k2YeQoz5zNzfm5urgli28/8tfa5rH38cxEmf77GtOuVa3R7/XFLsWn3NAO9LpwPfLpSdbtZIyyrA830KoKgPITjpqjwpIgXPb8A05fXzV8UepcaZs/UyRaLtDpmKIJFALoSUUciSgNwOYBpYcdMAzBWWR4DYB4zMxHlKpPNIKJOALoC2GSCTK5ES1ZHMzFzGOxESL4ZMRcHXJD+Yc2uktrl+6apKxqz0Psz7VGikEvKq2o7DkT1A79Ckw3OWbUbEz5aXrs+6NF5ltYjNpMjcdRo8AOGFYFi878VwCwAqwG8x8wriWgSEY1SDnsFQDMi2oCACSjoYnoWgGVEtASBSeSbmdldXUoPk2JiZjAn+nvvLtpm+BoPTzeWkjmIkY7k32fWHZlZ2Xs2cu3QrKNDH/8KRREmkN9auDXuHEWRsHO0LNTHlIAyZp4BYEbYtntDlo8BuETlvA8BfGiGDF7g5+0Hbb2fmcnQnJhkqzAh5fUeHXl3AHsUXg0zki2K/TaiCJbtOATguPkkkreTFeaVb9fvwzndWph+3Vh8sHgHxvRra/t93YZEFtvIpzb7Ta/epc2V81hlNVKSCClRahjoUSo7DpRh6ndbUF0D3Pur8JASe9HTZkWbMDXa9oUGmllpZRv13LeGrxEcSB4N+T5CzXRWqLDDDuWpWhMh06zfcNpryDfocRs0C2Zttv3THp2LP79X13PDSGTuut0leOXbzXh1wea4zg9i57TeoaOV6H7PTBRZVG0tNNDsCgujYM2YE3ng04A57faQZ+KRkEAwo7y1sH66khnLtRfTMROZOw4gisAmzAjJj4d/zFqLT5YURU1NfLCsEou3FGNEWO3hrSqulFogIlMK45jxkgZNJU/MXocNeyJnvgyaQdbuLlHdb2Z50cKt3vBlX150qHY5NysdAOpVM4uHiR+vqLet1IGOEmB+KhavIorAJpxSBKt3Hsaf3lmC/g/PiXrcL4eOYVXYMDleiaurzUkkZsYrGjRzPzN3PSZ8tCyiDf0HDeUZ7cSuLLWxCH5fF/RqBSDgITTXosydRmstx4O4kwYQRWATTqU//3KtvuC7UL/7eJVXZnpKxHTDI5/5RrOZzAzV2TgjtXZ50ZYDmLtavQ7xLW/9FPU6T89db4I0x4kVLOaW2g6/ejYw53CssrreiNFsXv7WmCkxHtxWctMpRBHYhFpyLzdSUV1TayapjtNr595PVkQ066z85TCKy7TlpTHjK5uzeje+XX88c2pVnF418UT5RiPWpG60iWs7CY4Sd6mMGBMBtyhcpxFFYBNujbwM9xXfsq+sdlu8jeb6PUdUba/BiWu7g9OufOV4WgG1uYvwOsR2EKuhj/e7F/Tx1sJtEWtU+wlRBDYRy0bvFGdMnldnfcQz39TWHq4wkD5ZLYIz2LZpNZO99p35poLMtPoe004M1oKms7zx01UnsZ3yoonEzJXHcy01zrDO6/z8J/Un2gO0FdOJxHCLTV5eQBSBDdz3SX0vCSdYExJXMH/tnohFV4K9UbMnLINzDlU1NZpGBduLzXflvOGNRfW2GVF48ZKcRLXRudtVajhMX+YuRRDK4WPWpWlYF2dReSdSoCQSoghsYOr37ijzOPW7LVinuEde+9oifLY0emNz1GQ7ddA8du7jX+EFh7KUqtUoqM2waiPJSVSraIN/3164Fa8qE6bSrGmnvKpad/ptoS6iCHzEf3/cjoKnvq5tkGOl/73nf8ZHMmUVVTh0tBJ546fjlAe+qN3+3LwNMc8deXJrw/ePRkVVDUqOVWKPSnlJqzl8rKr2d2jeKA3b9pfhro9XYJLiQtmmaYbtMrmFw8f0TeA+aILbad746Yav4WVEEVhMvNWprKKGgVIDRVb00vPeWfjzu0sA1E1lfUxDLeDcRumWyJQ3fjqKSysw/sNlOOn+L9CqSQNL7hOLexST4dpdR/C/JXUztw/u6s1U62Zw8v1f6HpvzPbo8iOSa8hiXvgyds/XbuzuAc9TCUA6s0tzW2UI59QHZ9cun9G5me15oIDjMR53fby83r5KB+Yt3ERVDSNNmQA+Ul6FBRv2ITcrHae2z6537Jqd6tHgejlUVokmDVNjH5iAiCKwmNW7zHlIzWTYE8aLuhulb/tsrN55GN1bZalGd57ywBe2+Xh/t9FdUcX//XEbSn2eNz9oNnt9wWbcr+Q+atIgFUvvO7/esfs11kuOxZ0fLsOLV/Uz5VpeQ0xDVuMuy5BrWLPzMIY//Q1mrdxdp3e+YMM+5I2fbmugzzQHRgPRmPDRcvy07aDTYjjK1uJSfL1ub60SUMNsZbm86BAu+/f3pl7TK4giEBzhi1WBVA83v7kYxaUVtUFdoYnO/IzfRwQFT32Dq1/9sc628FCBXvfNwnuLtqNBarIp9yw6eBQLNxdj0KOBTLHrd5dETdaYSIhpyGJ2OeCR4kUG/2M+crPSffPixcIrGUrtpKyiGoVbivHsvA14/NJTAAB3fLjM9PvsPHQMX6zchXH/WQwAmP3ns9C1ZRYAYPqyX3Bqh2y0dsjBwCrIi4EY+fn5XFhY6LQYmjj7H/Ox1YEUBoKQyJzYslHcwWfxsO6h4UhOInS+K1CIccvkkZbfc29JeW3678rqGqQmJ+FIeRUapcfffyeixcycH77dFNMQERUQ0Voi2kBE41X2pxPRu8r+hUSUF7JvgrJ9LRFdYIY8bkKUgCCYz+7D9o4c+z74Be6ftrLOttLyKqwwYMqsqWG8vXAbqmsYh8oq8dScdTikFBZiZvR/eA7W7y5BRVUNuk78HPPW7Ebv+2ZhT4n5VgbDioCIkgE8D2A4gJ4AfktE4fUJrwdwgJm7AHgSwN+Vc3sCuBxALwAFAF5QrqeL8qpqw+kQmFk1TL2sogoFT31de/2t+0s1ZYZk5noyScpbIRx5JOLD7qyhpeXV+M8PxzMElFdVo9d9s3ChkqZ7za7DWKc02sFU7rHaiU37juCuj5dj9c7DmPr9Fjw1Zz3+/fVG/LTtAK55LZAKZe+R8tpMtde9HrCC3PTGYny4eAfOffzLOterqKrBzBU7sWmv+kgp2ryTGXMEAwBsYOZNAEBE7wAYDSB0un80gPuV5Q8APEcBn8HRAN5h5nIAm4log3K9mFP3P24uxvbiMgzslIORz3wb0M4PXIAMZeLoaEU1UpPr1+FlDhRNYWbUMPDLwaNol9MQHSfMQIusdPw4cRiYGW8u3IaUJEJmegrW7CrBDW8U4ut1Ab/viSN6YGCnHMxcsQt/HNoV24vL0LVlFm57Zwn+t6QIn/3fmXh67nrMXlU3973LYssEFyCPhDfpdvfM2uW1u0pQ8NTxxHXDerTE0B4tMOGj5dgyeSTW7S5Bg9RkNG2Yite/24K8ZpnISE3Ggg2B9OgXPvstRvc5AQDwwpcb66Rf+deXG7EmzAX95+0H8fP2gwACOalyMtNQdKAMM1bswrw1e9C0QSp+3bcNGqQlY39pBfI7ZKO0vAr3f7oKyY1btFX7fwzPERDRGAAFzHyDsn4VgIHMfGvIMSuUY3Yo6xsBDERAOfzAzG8q218B8DkzfxDtnrkde3LmZY+p7rtyYHvkZKbhmXkbMKBjDpo0SEWbpg3QPqchpny9ybLJ2yHdcnUXgREEQbCTopduPly5f3uT8O2e8RoionEAxgFAcuNcZEY47s2F22qXf9xcbINkAUQJCILgerhG1V5lxmRxEYB2IettlW2qxxBRCoAmAPZrPBcAwMxTmDmfmfMzsuqHmQfp1jIL/Toc39+0QSo65WaiQKm5ahVm+TILgiBYiOq0lBkjgkUAuhJRRwQa8csBXBF2zDQAYxGw/Y8BMI+ZmYimAXibiJ4AcAKArgB+RAy6t8rCXy8+CYVbinHVoDxc8u/vUVFVg2X3n19bo3Z7cRmaNkxFVkbd3CF7S8rBzMhMT8Huw8ewvOgQRvdpU5t9cM2DBQCAcf9ZjM65mejaIgt3fbwcHZo1xNb9AQ+g24Z1Rc/WjfGvLzfi4V+fhJ2HjuLc7i1w4t2fo7Ka8faNA3HrWz9rLskoCIJ3GT+8OyZ/vqZ2/aQ2TTCgYw5e+XYzFt41FOt3l6BhegraZjfArW//jNM65aCiirF212HMVywJw3q0wJzV9XNytctugO0HItflGN67FU5u2xTbikvx07aDWKvMJ/RonYWiA0eRkZqMgt6tsOPAUSXnl3oRWVPiCIhoBICnACQDeJWZHyaiSQAKmXkaEWUA+A+AvgCKAVweMrk8EcB1AKoA3MbMn8e6X3gcweZ9pchMT0aLrPhT927bX4ZmjdKQGeaju/9IOfo9NAcrHwh4ti7cvB8nt22K5hEyY9bUMJKSCKXlVThWWY1+D7mzMpkgCPGRlpKE78efW/tub5k8El+v24OU5CSc2DILqclJaJyRgq37y5DXPJIRG1hRdAgXPvstXrumP37adgDPztuA35zaBoO75uI2JWPvf64fgD/+92ccKDvuJdU4IwW3DOmMv89cWyee4UBpBf799Uac17NVHasIEHCSWV50CKe0y1aNI5CAMovxe55zQbCCJg1SbXchHd67FT5fESjZuWXySGzaewRLdxzEr/uqOuLEpKq6BvdNW4l7LuyJ3YeP4bZ3l+CZy/uiXU5DMDM6TpiBz/7vTLRukoF+D83B45ecgtvfX4o5fzkbHZtn4lhldb2OaywiBZSJIrCYbnd/XicPvyAIxmnZON3WoLJFE4chMz0ZPe+dBcCeyOKftx1An3ZNQUQoOngUJzTJwNrdJejeqnHc17Q0sliITPdWWU6LIAgJw4Th3QEAb1w30NL7jDurU+3y2zcORG5WOhqmpeCOgm6YdusZlt47SN/22bUp2ts0bQAiMqQEouEZ91Gv0ja7IZbukIyasXjrhoHo1yEbz83bgOfmu6+Yj92c3LYJlslzU4eczDTcdHZn3HR259ptE0f0wJNz1qGswtz62neN6IHTOuWgXXbD2oRzAPD7IV1MvY9bkBGB4AjBkdJ4pYd3RpfmyEhNRo/W1vR4vEbDNH/30aZeN6D22QhSHRaaP/f2s3HtGXmmKYHGGYHvPGj2Obd7yzpKIJERRWA1kkxGlQt6tcKLV56KGwd3wuZHR9RuH3FSK6yaZG/uwbxmDW29Xywu6dcW+R0ix8r4gf552bj57M64elCH2m3h85mdcxvVSyFjhJPaNrHF9u9GRBFYTEsDLq1W8fq1/Z0WAZv3laKgd2skJ1GdUpVEhIZpKdgyeSQuy28X5QrmcXOIqcENPHbJKWjq09q5QVKSAk3TpNG9sWD8ubhxcEc8MLqX6rHpKeY0Y5MvPtmU63gRUQQWc0dBN6dFqEdflQLgdvPpstjlIRumWxetPfvPZ2FAxxwAzlVFa9Ig0NgX9GqFMf3quiCm+DxVbXLI/9+maQNMHNkzopvmSW3qpc6Ji3Y57hoZ2okoAovJcGHqicw0+2T6YcJQFN49LK5zj5o8ARhkzYMF6NoyCy9dnY+v/3YO1u0uiX2SBXz2f2cCAG48q2MdEwgAX08Uz7xtcB1FEItWTdw36vYaogh8RM/WjTH1ugG1dtW/XRB9tHLrOcY9JFo1yUDzRunY+MiI2vQdAHDlwA5RzgrwzqLthu+vRlA5N2mQivbNGjqirBtnpCBJaez2H6nAyW2b4sbBnXDlaYHvZfUuZ5STG9DrIvnQRb0N39OvcwNBRBHYQE+XeMI8evFJOPvEXACBRv6KAe2jHh8tPF4vyUmEJGUu4MNbBuG+X4XXLnKOOwu6xz7IZGoYSFa+j1TFxj1xZI/aRs3fhiF9NG2YJkWfDCKKwAZm/Gmw0yIACPimB/nrBd2QnZkW9fg0kybhggSH+7mNMjR5e1hhJ3/p6npBlbWJCu2kqqam1qTRPLN+3qrg/IUbCc5tuAlSz6UmaEQUgU189bchToug+rJ88oe6UZKvXdsfH94yCACQlhz/y6VWYDvYridpfOomjOgR9/0jofYf1TiQZiXoE7/mwQKc1Lb+ZGf4nIHThKZxtzLHT7wmmvAYAz24wYvOaUQR2ER4Omy3cEq7pnXW++floE+7gFdRitYWWwVWKcIYVERae29m9fEmX3xS7XJFdf28T+0d8BYJfreR5ieMfPeCdkae3BpDurVwWgzHkafNJrxiw8xISaqVVY/nRigf3DwINVHy7GVlaIuaNeM769chG5f1jx6PkKThRlk6szzG4sWr+kXdn5Hqrlczp1F0M6JX6eBjl9FQ3PW0JTBaGhsrGKjT1pySnBTSc4/vngfKKiOaW7ZMHmmrTT40YK1Jg9TayfJwbhvWNep17jV5cjuSHEEau8QOv/GRQNR344xUyz1rnEjQaMCilFCIIrCJZIcms/KaZWJMv7b49NYzox6n1vtPilPm1GQyxe5uxjsa/LfO6NIML12dHzF/+zWn50W9jtkT57FwS/xJ8Ln4cfN+AMCHt5yOId2iK7F4+dwBpwovpuG3An9ntvIBD4zupalR6Z+XjXfGDaqzLTdLvQpbLJiNTd6FXscoQWX21g2nabpXVkYKSo5VWSKL12ib3aB2+adtBwEETG1GOzV/PLcLnplXP8OsE54/PvxZVZERgU3orSRkFlqUwMK7hmJKmFvluoeGx50JtE12Awzp1gJnxTB/2IHWtiU7Mw1f/nUI+uepm9LUJr/18PHvT69dXvtQQZQjnSeYFuXpy/vWbrvnQvNMY38+78R6235zanxVvoxSI7YhAAYVARHlENFsIlqv/FVNYkNEY5Vj1hPR2JDtXxLRWiJaonwSevq+a4tGtt6vo8aAsJaNM+rZ7Y2YQk5smYVXr+mPN64bEPc1APt7a3nNMyN6KhkdEYTmd4rX5KYFM8wrGSmBzkOLkBFhaOyA0d9FrefvVGxCy8aSngIwPiIYD2AuM3cFMFdZrwMR5QC4D8BAAAMA3BemMH7HzH2Uzx6D8riaKwZGj+Q1m9Ym5mAZ2t1+HW2G/TbVxDTFZmGlIjDSoAbTccea37HCrt7zBGei7687s6Mj93UbRt+S0QCmKstTAVykcswFAGYzczEzHwAwG4C7x8YWYbeducrEYa8Tc92X9DOehvrB0cbz0ADGfrvwdNrxuuVqwYiSOTWkBsIrY/PRpmkD1ePOtaBTEJ591S6s/C28hFFF0JKZdyrLuwC0VDmmDYDQ7GE7lG1BXlPMQvdQgseJN7PZF7tdtrd9pJuYkJO/ReP4JrzNpKD38ajc+y3OsaS3XQt1CHjk14HAO2ZgaI+WEV2erxqUhykhcRBf/nUI/vcHe+r4GsXseJBEIaYiIKI5RLRC5TM69DgOjBf19pt+x8wnARisfK6KIsc4IiokosK9e/fqvI07GHXKCbbda1iPlnjkYnN6w4D2OYNmMfIX2U16ij43zBtDipaHYtbYysyKWmro7UsF8zllpafUOhaET4w/NubkOsoMqDvyaJCW7Jn6CTI1rE7Mp5KZhzFzb5XPJwB2E1FrAFD+qtn4iwCEjo3bKtvAzMG/JQDeRmAOIZIcU5g5n5nzc3Od90aJBzsHPFkZKbobQTWG9QiYAbTK/s9LTzF8Tyc5rVOzetsu6nMCBndtHvc1gw3ruLM64byeaoNm89Br6khR8klVR7F9XZLfrl7uqND7EOm/75+GHg/gs9MsdP8o9Spnfsdo92QagKAX0FgAn6gcMwvA+USUrUwSnw9gFhGlEFFzACCiVAAXAlhhUB5BwaxEai+P7R+IKNV4ucqqKLklPMqdw7sb8i4J/hR3jeih6zrrHx6u+156O+bBmIC3b4weZxFOaL8gLTlJd2nNUBfSf15iX+fBqbkIt2PUYDYZwHtEdD2ArQAuBQAiygdwMzPfwMzFRPQggEXKOZOUbZkIKIRUAMkA5gB4yaA8gkVoVSztXVYI3gyyGxozd8Ufoa2/n9ZAZ/W5oKmqT0jyQS0/dWhSvKYN09BE4/Oh1aVZsBdDioCZ9wMYqrK9EMANIeuvAng17JhSANEzbwlx40QkbGoy6a4uZSXBUpBGmKQxMjvIWzcMxO9eXlhnm13pKf73hzNAOnO2qkUJa/E2C0+OqtV0GBpbkpOZhuLSCk3nCdbiPidrwRTMzrGvJRiusto9U3ErHrgAvU0oaq63V35Gl/pzCXrNJvFC0O/mq3b8f3/cFvM8M2IhFt89DD/eVa8fGRdPX97HlOv4FVEENhNvIXenuW1Y/bQAbkatME48qE3uXnN6Hp67oi/SoiiJX/dtg6X3no8fJw5FrxOMKyQtEOmv4aDWoB/WUHgmXjUQ2kEhIrQwKbI3r5mYnIwgisBmmjeyx6/d7L65U2m0nUbt97ptWFdcePIJUXvfAzvmoEnDVLTIsi+FAYF0e6bFG1AVrwfcsUprnAn0jlAuzZdJ41BEESQodqfXzbbJ/OEGgtXmpv9RPa/P69f2x0V926ju08vP95yn+dh43Djj1e/xWoZObGlNvi3dJjHT6t8lBhJmJ5jCtFvPRKVKGUgv8tq1/XHta4vqbR87qANG9TmhtrHtEmHexMzSh9k6AvSCjWFuVjr2lpRrOsfukZ5VsTRaLts/L5BC47ExJ6umGvczMiJIUOz2GmqX0xCdcu3NrmoVfdo2Vd3+wOje6Nehbprqn3T02K0m2MvV09RmqAQdanl03GYp1NLDD+auuiS/neQYCkNGBAmKGYVh/Ep2Zhomje6FI+Wxe405mWlo3igd+46UW17KMRbBXrGeTvcNgzvizuHddd/LbYV6wt1Z1ZDGPzIyInCAm85Wz2djJmZmHg0Sq9ylW1hsgmfW1YPy8PshXUyQxj6CE6anRBjRqNEgNblOMJnW893kKgzot/kndnpL/YgiSFCssNef1LZJPW+LrIwUZKS66zFqZpNn1nHc0SgGG7cXfneqoeu0y4kdHR4ccd41Qv9owgq0dPZDG3/RA3UR05AT2NBulFuU8+cfY07BwI7NcPv7SwE4U3BcUCfYuBnJcDrztsER6xCEEqxnPO6sznHfy0wSPIO95YgicIAcG1I1V1iY/C3UFt3W4zUPEonwxjArPQUlGuY5QtGaIiSveabjcyKhaNEDoccM6twMp7Zvapk8XsNdY3qfcL0N5fGszA8ffKGeuqyPZffwEu/eNMiUvEZGCW8MNXWSE6QjrTegrEuLLHz0e28U07EDUQQOYFZxkk8iVIV6ZWw+/nWl9fn8RvcxJ2jKTF6/tr/t9+yc28iUvEZGSQ1znfGTl4w2neef70Mvogg8TPMs9UnRri2y6pQgNJsuuVmWXdso3Vq5VzarCRaZAYABHXO0pc92xzy3YcxIgudnRBF4mEiPfmiDYAUntW3iKvuwn/nNqce9uEJjR967aVBc9Qys5M3rB1p2bdEDxnDXkyLoItLQ3yv1YwVtvHjlqXX+hpKvpE0AUG8UGF572Gn0Fs3Rg5/MYFYgisDDROoF+fmlaKCjiIxXKOjdWlmK/rvqKaBjRe98aPfoOZaqLMxFpaXzI6OGyIgicAmTRusvqk0g3H5e/ToBmSbl4vcas247C00NlpX0Gh00BH+p0aSB+dliX7km+kS9FdHutUgjbwhDioCIcohoNhGtV/5mRzhuJhEdJKLPwrZ3JKKFRLSBiN4lIn+9xSEM61G/AEosiIB+efW/cj09w0QhOYl8N1G8ZfLIiBlQo3HN6XnHzzO5Af3tgHYR91npWSWTxcYwOiIYD2AuM3cFMFdZV+MxAFepbP87gCeZuQuAAwCuNyiPZ4mnr0QQlzi/MUOpgbBQKfGo57kZ0i0XAHD/qF6W2ev/rDJCDWLFKCSIXQWfEhWjimA0gKnK8lQAF6kdxMxzAZSEbqNAGOS5AD6Idb4fsLuQTKLhB3VIBPQ8oXHtMqAvC6hatHlLk0pF1uLgYzzqlBOcu7nHMaoIWjLzTmV5FwA99o1mAA4yczAGfgcA90Uo2UQ8ekDyqxzHUvuzSwg1+QVHgjU6HpxrTs+rs75l8kh0TpAaEkBgpCPER0xFQERziGiFymd06HEc6NJa9jYS0TgiKiSiwr1791p1G9toZkK+IYKMJEb3CfQCO+UmdvHytJQkdA2ZD6gdEei4xvm9WmHVpAvMFcxFxOoWSccpMjEVATMPY+beKp9PAOwmotYAoPzdo+Pe+wE0JaKgi0tbAEVR5JjCzPnMnJ+bm6vjNu5kcVhlq3ja8yQi1cliP9KpeWIrgnUPDccJIVlBGymeYXo7Ag3TrPUoc7Jb0ihD/X9b82CBzZJ4D6OmoWkAxirLYwF8ovVEZQQxH8CYeM5PNKIN8SM+yASkq5QaFBKbLZNH1pqJ0lwWPezkADVSJLUfvej0YvQpmgzgPCJaD2CYsg4iyieil4MHEdE3AN4HMJSIdhBRcHx6J4C/ENEGBOYMXjEoj2eJ9v5EepAb+TReIJTgYN+vFrIWUSZ7/fqdCPox1JIw834AQ1W2FwK4IWRdtXoJM28CMMCIDIlCPLb+YARxw7RklFVUmy2SJ2AAN53VCWd2be60KI6Rk5mG4tIKp8UQPIy7xpU+ZcvkkYZsq3eP7GmaLG7j4lNjO5JNGNEDg7t6f97IbC7NjxzcZRVuy28UikwVR0YUgUvQOyB447rjA6krBrY3WRr30LSBb4PNNXOsUn00eONZnWyWJDK/H+KOkpaCOqIIXMIJTTMwWId5o1+Hut5C5/XUn6LC63TOzcSZXfxrEgry0e9Px/Q/Ol8hDYjcobmjwB1F7gV1RBG4hIZpKfiPjoyQ4YnlXro632yRXM/c24fgEgfMH26je6vG6HWC8xXSAOtrYcTLyJNbY0DHHKfFcC3idiK4GokB8hYtskxOWWGQ684I1Ad//or6tRyE48iIwEHUcudf0Mt/Jh4hsdj4yAjcOLij02IAAE7t0NRpETyBKAIHUZvkvfK0DoaumWg96AT7d3xBchJJOgePIaYhB7mjoBtuOttczw4pUym4gbxmx1N+bHpkhIOSCFqQEYGDpKckm25TTbQCHQn27/iG3w5oVxv5nuRg50TqdWhDFIHLMJoWINHqFYuJwZsQES7vb79Hl9kjbL8gisAjFPRqpek4K6tACYIefFAiImGQOQKXEakDrMU/+9s7z4mYgdGryHjAu+gpmmMWYgqKj8RqNRKINiG55wFtk8BtsxuaX3rQaeS99ixOFE1yc64jNyOKwCMkJ/nzp5Ienndxg2nIiVGJF/Fn6+JBEszioxmZKzaGnvxVZlPtAtPQnNW7bZfBi8gcgUdING8gwV6euPQU2+95y9md0bddU9vvG0ppuT/rdOjFp/1M7zHqlNh5+RMRUX/mcFEf+5+fdjkNbU8K2Kpxep11MQ1pQxSBy+ic20h1+6DOzdDH4d6VE0T0opIRkiaCzgNOBnXZydWD8uqsd22h/j4JdTGkCIgoh4hmE9F65W92hONmEtFBIvosbPvrRLSZiJYonz5G5EkETgjzFgqlfU5DAECrRPMMikKkyeIF48+1WRJv8tBFvbFo4jCnxbCNcIU34qTWDkniLYyOCMYDmMvMXQHMVdbVeAzAVRH2/Y2Z+yifJQblSWiCveMJI/xT5CPSiKBFVrr6DqEOGanJyPXxd5VoKVeswqgiGA1gqrI8FcBFagcx81wAJQbv5TuW338+1j88vHY9RXEh7RCS0Mtv3HpOFwCSekLQhjwm2jCqCFoy805leReAeJLpP0xEy4joSSLyb9clhDsKuuGOgm7IykitEyl8z4U9AADJPnq6w//Tm6X2raADGRFoI6b7KBHNAaCW6GZi6AozMxHpnaKfgIACSQMwBcCdACZFkGMcgHEA0L594hZrB4DfD+miur1pw0Ahd1+5koa9yE5EqwrepWlDyb2lhZiKgJkjzjQR0W4ias3MO4moNYA9em4eMpooJ6LXAPw1yrFTEFAWyM/P93VrkJ3pn4fbRypPsIBozhfCcYyahqYBGKssjwXwiZ6TFeUBChh8LwKwwqA8Cc+mR0agdRP/PNwtwvzCfd0DEDQx67aznBbBcxhVBJMBnEdE6wEMU9ZBRPlE9HLwICL6BsD7AIYS0Q4iukDZ9RYRLQewHEBzAA8ZlCfh8Ys/eJDf9m+PxXcfH5SKZUiIhVotcCE6hlJMMPN+AENVthcCuCFkfXCE88UZXIhKUhKhWaOQUYEoAiEGMj+sH4ksFjxFeqo8skJ0/DZqNgN5qwRPkZGajC2TRzothuBiRA/oRxSBIAgJhZ/ibMxCFIEgCAlFY6nbrRtRBIIgJBQZqcno3irLaTE8hSgCQRAEnyOKQPAMF/U5wWkRBI8gSQn1IYpA8AwN06WyqqANUQP6EEUgCELCIQMCfYgiEDyDvNuCVkQR6EMUgeAZ5OUWtBKpxKmgjhhdBUFIOP5wThes/OWQ02J4BlEEgiAkHAW9W6Ggt1o9LUENMQ0JnqBzbibO7JLrtBiCkJDIiEDwBHNvH+K0CIKQsMiIQBAEweeIIhAEQfA5oggEQRB8jiFFQEQ5RDSbiNYrf7NVjulDRN8T0UoiWkZEl4Xs60hEC4loAxG9S0RpRuQRBEEQ9GN0RDAewFxm7gpgrrIeThmAq5m5F4ACAE8RUVNl398BPMnMXQAcAHC9QXkEQRAEnRhVBKMBTFWWpwK4KPwAZl7HzOuV5V8A7AGQS4H0gOcC+CDa+YIgCIK1GFUELZl5p7K8C0DLaAcT0QAAaQA2AmgG4CAzVym7dwBoY1AeQRAEQScx4wiIaA4AtRC9iaErzMxExFGu0xrAfwCMZeYavfnCiWgcgHEA0L59e13nCoIgCJGJqQiYeVikfUS0m4haM/NOpaHfE+G4xgCmA5jIzD8om/cDaEpEKcqooC2AoihyTAEwRbleCRGtjSW7gzQHsM9pIWLgdhndLh8gMpqB2+UD3C+jHvk6qG00Glk8DcBYAJOVv5+EH6B4An0M4A1mDs4HBEcQ8wGMAfBOpPMjsJaZ8w3KbhlEVOhm+QD3y+h2+QCR0QzcLh/gfhnNkM/oHMFkAOcR0XoAw5R1EFE+Eb2sHHMpgLMAXENES5RPH2XfnQD+QkQbEJgzeMWgPIIgCIJODI0ImHk/gKEq2wsB3KAsvwngzQjnbwIwwIgMgiAIgjG8Glk8xWkBYuB2+QD3y+h2+QCR0QzcLh/gfhkNy0fMER19BEEQBB/g1RGBIAiCYBKeUgREVEBEa5XcRGrpLFwlDxFdQ0R7QybJb3BCzjCZXiWiPUS0wmlZgNjyENEQIjoU8h3ea7eMKjK1I6L5RLRKyaH1JzfL4tLvMIOIfiSipYrcD7hZFje+y0GIKJmIfiaiz+K+CDN74gMgGYGI5E4IRCcvBdDTzfIAuAbAc05/d2EynQXgVAArnJZFizwAhgD4zGk5w2RqDeBUZTkLwDqnnkUtsrj0OyQAjZTlVAALAZzmVlnc+C6HyPYXAG8b+Y29NCIYAGADM29i5goEYg9Gizz6YOavARQ7LUcQt8mjBWbeycw/KcslAFbDofQobpJFDxzgiLKaqnwcmbB0kyx6IaK2AEYCeDnWsdHwkiJoA2B7yLrTuYm0yvMbJf32B0TUzh7REo5ByrD9cyLq5bQwoRBRHoC+CPQiHSWGLK77DhWTxhIEMhLMZmbHvkONsrjxXX4KwB0AaoxcxEuKwIt8CiCPmU8GMBvHM7UK2vkJQAdmPgXAswD+56w4xyGiRgA+BHAbMx92sSyu/A6ZuZqZ+yCQXmYAEfV2sSyue5eJ6EIAe5h5sdFreUkRFAEI1cJRcxPZQEx5mHk/M5crqy8D6GeTbAkDMx8ODtuZeQaAVCJq7rBYIKJUBBret5j5IzfL4tbvMAgzHwQwH4F6JY4SSRaXvstnABhFRFsQME2fS0Sqwbux8JIiWASgKwWqmqUBuByBXEeulUdJxBdkFAL2W0EHRNSKlFS1FEhjnoRAwkInZSIE0qGsZuYn3C6LS7/DXFIKVBFRAwDnAVjjVlnc+C4z8wRmbsvMeQi0P/OY+cp4rmU06ZxtMHMVEd0KYBYCHjuvMvNKt8lDRJMAFDLzNAB/JKJRAKoQmBC9xil5gxDRfxHwImlORDsA3MfMjuV4UpMHgck6MPOLCCQlvIWIqgAcBXA5K64SDnIGgKsALFfsygBwl9LbdoUsANoDrv4OWwOYSkTJCCim95g5fvdHC2Rx+7tsJhJZLAiC4HO8ZBoSBEEQLEAUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPkcUgSBEgYiahWSc3EVERcryESJ6wWn5BMEMxH1UEDRCRPcDOMLM/3RaFkEwExkRCEIcKDn+P1OW7yeiqUT0DRFtJaKLiegfRLSciGYqKSBARP2I6CsiWkxEs8KiVQXBMUQRCII5dAZwLgLpB94EMJ+ZT0IgknekogyeBTCGmfsBeBXAw04JKwiheCbFhCC4nM+ZuZKIliOQcmSmsn05gDwA3QD0BjBbSfuTDGCnA3IKQj1EEQiCOZQDADPXEFFlSC6fGgTeMwKwkpkHOSWgIERCTEOCYA9rAeQS0SAgkDraLQViBEEUgSDYgFLOdAyAvxPRUgBLAJzuqFCCoCDuo4IgCD5HRgSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPgcUQSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPic/wcvziJ0eY2VRAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "librosa.display.waveplot(samples_out, sr=sr)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "special-delicious", - "metadata": {}, - "outputs": [], - "source": [ - "import getpass" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "seasonal-consensus", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['GetPassWarning',\n", - " '__all__',\n", - " '__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " '_raw_input',\n", - " 'contextlib',\n", - " 'fallback_getpass',\n", - " 'getpass',\n", - " 'getuser',\n", - " 'io',\n", - " 'os',\n", - " 'sys',\n", - " 'termios',\n", - " 'unix_getpass',\n", - " 'warnings',\n", - " 'win_getpass']" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(getpass)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "dress-distinction", - "metadata": {}, - "outputs": [], - "source": [ - "getpass?" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "rental-anthony", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Worker:" - ] - } - ], - "source": [ - "import multiprocessing\n", - "import cProfile\n", - "import time\n", - "\n", - "def worker(num):\n", - " time.sleep(3)\n", - " print('Worker:', num)\n", - "\n", - "def profile_worker(num):\n", - " cProfile.runctx('worker(num)', globals(), locals(), 'profile-%d.out' %num)\n", - "\n", - "\n", - "\n", - "for i in range(5):\n", - " p = multiprocessing.Process(target=profile_worker, args=(i,))\n", - " p.start()" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "separated-restriction", - "metadata": {}, - "outputs": [], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "painted-variable", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(2, 2)\n", - "[ 1 20]\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "l = [(1, 20), (2, 30)]\n", - "scores = np.array(l)\n", - "print(scores.shape)\n", - "print(scores[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "satellite-insider", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0 1]\n" - ] - } - ], - "source": [ - "sort_idx = np.argsort(scores[:, -1])\n", - "print(sort_idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "developed-thirty", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - } - ], - "source": [ - "sorted_val_scores = scores[sort_idx][::1]\n", - "print(sorted_val_scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "official-bench", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - } - ], - "source": [ - "sorted_val_scores = scores[sort_idx]\n", - "print(sorted_val_scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "ranking-camera", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x14\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x1e\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\n", - "[ 1 20 2 30]\n", - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: DeprecationWarning: tostring() is deprecated. Use tobytes() instead.\n", - " \"\"\"Entry point for launching an IPython kernel.\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:3: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead\n", - " This is separate from the ipykernel package so we can avoid doing imports until\n" - ] - } - ], - "source": [ - "a = scores.tostring()\n", - "print(a)\n", - "b = np.fromstring(a, scores.dtype)\n", - "print(b)\n", - "print(scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "breeding-proxy", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "numpy.int16" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.int16" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "coordinate-hungary", - "metadata": {}, - "outputs": [], - "source": [ - "dtype = np.dtype('int16')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "specified-jackson", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "int16\n", - "16\n" - ] - } - ], - "source": [ - "print(dtype)\n", - "dtype is np.int16\n", - "print(np.iinfo(dtype).bits)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "activated-insight", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/train_test.ipynb b/.notebook/train_test.ipynb deleted file mode 100644 index 67212e50a..000000000 --- a/.notebook/train_test.ipynb +++ /dev/null @@ -1,1887 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cloudy-glass", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['CUDA_VISISBLE_DEVICES'] = '0'" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "grand-stephen", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.0.0\n" - ] - } - ], - "source": [ - "import paddle\n", - "print(paddle.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "isolated-prize", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "romance-samuel", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/data/manifest.dev',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/aishell/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'linear',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "timely-bikini", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "from data_utils.dataset import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=False,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "organized-warrior", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " if arr.dtype == np.object:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[14 , 34 , 322 , 233 , 0 , 0 ],\n", - " [238 , 38 , 122 , 164 , 0 , 0 ],\n", - " [8 , 52 , 49 , 42 , 0 , 0 ],\n", - " [109 , 47 , 146 , 193 , 210 , 479 ],\n", - " [3330, 1751, 208 , 1923, 0 , 0 ]])\n", - "test raw 大时代里的的\n", - "test raw 煲汤受宠的的\n", - "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [4, 4, 4, 6, 4])\n", - "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", - " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", - " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", - " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", - " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", - " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", - " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", - " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", - " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", - " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", - " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", - " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", - " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", - " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", - " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", - " ...,\n", - " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", - " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", - " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", - "\n", - " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", - " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", - " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", - " ...,\n", - " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", - " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", - " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" - ] - } - ], - "source": [ - " for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test', text)\n", - " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n", - " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n", - " print('audio len', audio_len)\n", - " print('test len', text_len)\n", - " print('audio', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "confidential-radius", - "metadata": {}, - "outputs": [], - "source": [ - "# reader = batch_reader()\n", - "# audio, test , audio_len, text_len = reader.next()\n", - "# print('test', text)\n", - "# print('t len', text_len) #[B, T]\n", - "# print('audio len', audio_len)\n", - "# print(audio)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "future-vermont", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "煲汤受宠\n" - ] - } - ], - "source": [ - "print(u'\\u7172\\u6c64\\u53d7\\u5ba0')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dental-sweden", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "sunrise-contact", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hispanic-asthma", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hearing-leadership", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "skilled-friday", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "copyrighted-measure", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "employed-lightweight", - "metadata": {}, - "outputs": [], - "source": [ - "from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n", - "\n", - "from data_utils.dataset import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=False,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)\n", - "\n", - "\n", - "import paddle\n", - "from paddle import nn\n", - "from paddle.nn import functional as F\n", - "from paddle.nn import initializer as I\n", - "\n", - "import math\n", - "\n", - "def brelu(x, t_min=0.0, t_max=24.0, name=None):\n", - " t_min = paddle.to_tensor(t_min)\n", - " t_max = paddle.to_tensor(t_max)\n", - " return x.maximum(t_min).minimum(t_max)\n", - "\n", - "def sequence_mask(x_len, max_len=None, dtype='float32'):\n", - " max_len = max_len or x_len.max()\n", - " x_len = paddle.unsqueeze(x_len, -1)\n", - " row_vector = paddle.arange(max_len)\n", - " mask = row_vector > x_len # maybe a bug\n", - " mask = paddle.cast(mask, dtype)\n", - " print(f'seq mask: {mask}')\n", - " return mask\n", - "\n", - "\n", - "class ConvBn(nn.Layer):\n", - " def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n", - " padding, act):\n", - "\n", - " super().__init__()\n", - " self.kernel_size = kernel_size\n", - " self.stride = stride\n", - " self.padding = padding\n", - "\n", - " self.conv = nn.Conv2D(\n", - " num_channels_in,\n", - " num_channels_out,\n", - " kernel_size=kernel_size,\n", - " stride=stride,\n", - " padding=padding,\n", - " weight_attr=None,\n", - " bias_attr=None,\n", - " data_format='NCHW')\n", - "\n", - " self.bn = nn.BatchNorm2D(\n", - " num_channels_out,\n", - " weight_attr=None,\n", - " bias_attr=None,\n", - " data_format='NCHW')\n", - " self.act = F.relu if act == 'relu' else brelu\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x(Tensor): audio, shape [B, C, D, T]\n", - " \"\"\"\n", - " x = self.conv(x)\n", - " x = self.bn(x)\n", - " x = self.act(x)\n", - "\n", - " x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n", - " ) // self.stride[1] + 1\n", - "\n", - " # reset padding part to 0\n", - " masks = sequence_mask(x_len) #[B, T]\n", - " masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n", - " x = x.multiply(masks)\n", - "\n", - " return x, x_len\n", - "\n", - "\n", - "class ConvStack(nn.Layer):\n", - " def __init__(self, feat_size, num_stacks):\n", - " super().__init__()\n", - " self.feat_size = feat_size # D\n", - " self.num_stacks = num_stacks\n", - "\n", - " self.conv_in = ConvBn(\n", - " num_channels_in=1,\n", - " num_channels_out=32,\n", - " kernel_size=(41, 11), #[D, T]\n", - " stride=(2, 3),\n", - " padding=(20, 5),\n", - " act='brelu')\n", - "\n", - " out_channel = 32\n", - " self.conv_stack = nn.Sequential([\n", - " ConvBn(\n", - " num_channels_in=32,\n", - " num_channels_out=out_channel,\n", - " kernel_size=(21, 11),\n", - " stride=(2, 1),\n", - " padding=(10, 5),\n", - " act='brelu') for i in range(num_stacks - 1)\n", - " ])\n", - "\n", - " # conv output feat_dim\n", - " output_height = (feat_size - 1) // 2 + 1\n", - " for i in range(self.num_stacks - 1):\n", - " output_height = (output_height - 1) // 2 + 1\n", - " self.output_height = out_channel * output_height\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x: shape [B, C, D, T]\n", - " x_len : shape [B]\n", - " \"\"\"\n", - " print(f\"conv in: {x_len}\")\n", - " x, x_len = self.conv_in(x, x_len)\n", - " for i, conv in enumerate(self.conv_stack):\n", - " print(f\"conv in: {x_len}\")\n", - " x, x_len = conv(x, x_len)\n", - " print(f\"conv out: {x_len}\")\n", - " return x, x_len\n", - " \n", - " \n", - "\n", - "class RNNCell(nn.RNNCellBase):\n", - " r\"\"\"\n", - " Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n", - " computes the outputs and updates states.\n", - " The formula used is as follows:\n", - " .. math::\n", - " h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n", - " y_{t} & = h_{t}\n", - " \n", - " where :math:`act` is for :attr:`activation`.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " hidden_size,\n", - " activation=\"tanh\",\n", - " weight_ih_attr=None,\n", - " weight_hh_attr=None,\n", - " bias_ih_attr=None,\n", - " bias_hh_attr=None,\n", - " name=None):\n", - " super().__init__()\n", - " std = 1.0 / math.sqrt(hidden_size)\n", - " self.weight_hh = self.create_parameter(\n", - " (hidden_size, hidden_size),\n", - " weight_hh_attr,\n", - " default_initializer=I.Uniform(-std, std))\n", - " # self.bias_ih = self.create_parameter(\n", - " # (hidden_size, ),\n", - " # bias_ih_attr,\n", - " # is_bias=True,\n", - " # default_initializer=I.Uniform(-std, std))\n", - " self.bias_ih = None\n", - " self.bias_hh = self.create_parameter(\n", - " (hidden_size, ),\n", - " bias_hh_attr,\n", - " is_bias=True,\n", - " default_initializer=I.Uniform(-std, std))\n", - "\n", - " self.hidden_size = hidden_size\n", - " if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n", - " raise ValueError(\n", - " \"activation for SimpleRNNCell should be tanh or relu, \"\n", - " \"but get {}\".format(activation))\n", - " self.activation = activation\n", - " self._activation_fn = paddle.tanh \\\n", - " if activation == \"tanh\" \\\n", - " else F.relu\n", - " if activation == 'brelu':\n", - " self._activation_fn = brelu\n", - "\n", - " def forward(self, inputs, states=None):\n", - " if states is None:\n", - " states = self.get_initial_states(inputs, self.state_shape)\n", - " pre_h = states\n", - " i2h = inputs\n", - " if self.bias_ih is not None:\n", - " i2h += self.bias_ih\n", - " h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n", - " if self.bias_hh is not None:\n", - " h2h += self.bias_hh\n", - " h = self._activation_fn(i2h + h2h)\n", - " return h, h\n", - "\n", - " @property\n", - " def state_shape(self):\n", - " return (self.hidden_size, )\n", - "\n", - "\n", - "class GRUCellShare(nn.RNNCellBase):\n", - " r\"\"\"\n", - " Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n", - " it computes the outputs and updates states.\n", - " The formula for GRU used is as follows:\n", - " .. math::\n", - " r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n", - " z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n", - " \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n", - " h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n", - " y_{t} & = h_{t}\n", - " \n", - " where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n", - " multiplication operator.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " input_size,\n", - " hidden_size,\n", - " weight_ih_attr=None,\n", - " weight_hh_attr=None,\n", - " bias_ih_attr=None,\n", - " bias_hh_attr=None,\n", - " name=None):\n", - " super().__init__()\n", - " std = 1.0 / math.sqrt(hidden_size)\n", - " self.weight_hh = self.create_parameter(\n", - " (3 * hidden_size, hidden_size),\n", - " weight_hh_attr,\n", - " default_initializer=I.Uniform(-std, std))\n", - " # self.bias_ih = self.create_parameter(\n", - " # (3 * hidden_size, ),\n", - " # bias_ih_attr,\n", - " # is_bias=True,\n", - " # default_initializer=I.Uniform(-std, std))\n", - " self.bias_ih = None\n", - " self.bias_hh = self.create_parameter(\n", - " (3 * hidden_size, ),\n", - " bias_hh_attr,\n", - " is_bias=True,\n", - " default_initializer=I.Uniform(-std, std))\n", - "\n", - " self.hidden_size = hidden_size\n", - " self.input_size = input_size\n", - " self._gate_activation = F.sigmoid\n", - " #self._activation = paddle.tanh\n", - " self._activation = F.relu\n", - "\n", - " def forward(self, inputs, states=None):\n", - " if states is None:\n", - " states = self.get_initial_states(inputs, self.state_shape)\n", - "\n", - " pre_hidden = states\n", - " x_gates = inputs\n", - " if self.bias_ih is not None:\n", - " x_gates = x_gates + self.bias_ih\n", - " h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n", - " if self.bias_hh is not None:\n", - " h_gates = h_gates + self.bias_hh\n", - "\n", - " x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n", - " h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n", - "\n", - " r = self._gate_activation(x_r + h_r)\n", - " z = self._gate_activation(x_z + h_z)\n", - " c = self._activation(x_c + r * h_c) # apply reset gate after mm\n", - " h = (pre_hidden - c) * z + c\n", - " # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n", - " #h = (1-z) * pre_hidden + z * c\n", - "\n", - " return h, h\n", - "\n", - " @property\n", - " def state_shape(self):\n", - " r\"\"\"\n", - " The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n", - " size would be automatically inserted into shape). The shape corresponds\n", - " to the shape of :math:`h_{t-1}`.\n", - " \"\"\"\n", - " return (self.hidden_size, )\n", - "\n", - "\n", - "class BiRNNWithBN(nn.Layer):\n", - " \"\"\"Bidirectonal simple rnn layer with sequence-wise batch normalization.\n", - " The batch normalization is only performed on input-state weights.\n", - "\n", - " :param name: Name of the layer parameters.\n", - " :type name: string\n", - " :param size: Dimension of RNN cells.\n", - " :type size: int\n", - " :param share_weights: Whether to share input-hidden weights between\n", - " forward and backward directional RNNs.\n", - " :type share_weights: bool\n", - " :return: Bidirectional simple rnn layer.\n", - " :rtype: Variable\n", - " \"\"\"\n", - "\n", - " def __init__(self, i_size, h_size, share_weights):\n", - " super().__init__()\n", - " self.share_weights = share_weights\n", - " if self.share_weights:\n", - " #input-hidden weights shared between bi-directional rnn.\n", - " self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", - " # batch norm is only performed on input-state projection\n", - " self.fw_bn = nn.BatchNorm1D(\n", - " h_size, bias_attr=None, data_format='NLC')\n", - " self.bw_fc = self.fw_fc\n", - " self.bw_bn = self.fw_bn\n", - " else:\n", - " self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", - " self.fw_bn = nn.BatchNorm1D(\n", - " h_size, bias_attr=None, data_format='NLC')\n", - " self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", - " self.bw_bn = nn.BatchNorm1D(\n", - " h_size, bias_attr=None, data_format='NLC')\n", - "\n", - " self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n", - " self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n", - " self.fw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n", - " self.bw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n", - "\n", - " def forward(self, x, x_len):\n", - " # x, shape [B, T, D]\n", - " fw_x = self.fw_bn(self.fw_fc(x))\n", - " bw_x = self.bw_bn(self.bw_fc(x))\n", - " fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n", - " bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n", - " x = paddle.concat([fw_x, bw_x], axis=-1)\n", - " return x, x_len\n", - "\n", - "\n", - "class BiGRUWithBN(nn.Layer):\n", - " \"\"\"Bidirectonal gru layer with sequence-wise batch normalization.\n", - " The batch normalization is only performed on input-state weights.\n", - "\n", - " :param name: Name of the layer.\n", - " :type name: string\n", - " :param input: Input layer.\n", - " :type input: Variable\n", - " :param size: Dimension of GRU cells.\n", - " :type size: int\n", - " :param act: Activation type.\n", - " :type act: string\n", - " :return: Bidirectional GRU layer.\n", - " :rtype: Variable\n", - " \"\"\"\n", - "\n", - " def __init__(self, i_size, h_size, act):\n", - " super().__init__()\n", - " hidden_size = h_size * 3\n", - " self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n", - " self.fw_bn = nn.BatchNorm1D(\n", - " hidden_size, bias_attr=None, data_format='NLC')\n", - " self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n", - " self.bw_bn = nn.BatchNorm1D(\n", - " hidden_size, bias_attr=None, data_format='NLC')\n", - "\n", - " self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n", - " self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n", - " self.fw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n", - " self.bw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n", - "\n", - " def forward(self, x, x_len):\n", - " # x, shape [B, T, D]\n", - " fw_x = self.fw_bn(self.fw_fc(x))\n", - " bw_x = self.bw_bn(self.bw_fc(x))\n", - " fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n", - " bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n", - " x = paddle.concat([fw_x, bw_x], axis=-1)\n", - " return x, x_len\n", - "\n", - "\n", - "class RNNStack(nn.Layer):\n", - " \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n", - "\n", - " :param input: Input layer.\n", - " :type input: Variable\n", - " :param size: Dimension of RNN cells in each layer.\n", - " :type size: int\n", - " :param num_stacks: Number of stacked rnn layers.\n", - " :type num_stacks: int\n", - " :param use_gru: Use gru if set True. Use simple rnn if set False.\n", - " :type use_gru: bool\n", - " :param share_rnn_weights: Whether to share input-hidden weights between\n", - " forward and backward directional RNNs.\n", - " It is only available when use_gru=False.\n", - " :type share_weights: bool\n", - " :return: Output layer of the RNN group.\n", - " :rtype: Variable\n", - " \"\"\"\n", - "\n", - " def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n", - " super().__init__()\n", - " self.rnn_stacks = nn.LayerList()\n", - " for i in range(num_stacks):\n", - " if use_gru:\n", - " #default:GRU using tanh\n", - " self.rnn_stacks.append(\n", - " BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n", - " else:\n", - " self.rnn_stacks.append(\n", - " BiRNNWithBN(\n", - " i_size=i_size,\n", - " h_size=h_size,\n", - " share_weights=share_rnn_weights))\n", - " i_size = h_size * 2\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x: shape [B, T, D]\n", - " x_len: shpae [B]\n", - " \"\"\"\n", - " for i, rnn in enumerate(self.rnn_stacks):\n", - " x, x_len = rnn(x, x_len)\n", - " masks = sequence_mask(x_len) #[B, T]\n", - " masks = masks.unsqueeze(-1) # [B, T, 1]\n", - " x = x.multiply(masks)\n", - " return x, x_len\n", - "\n", - " \n", - "class DeepSpeech2Test(DeepSpeech2):\n", - " def __init__(self,\n", - " feat_size,\n", - " dict_size,\n", - " num_conv_layers=2,\n", - " num_rnn_layers=3,\n", - " rnn_size=256,\n", - " use_gru=False,\n", - " share_rnn_weights=True):\n", - " super().__init__(feat_size,\n", - " dict_size,\n", - " num_conv_layers=2,\n", - " num_rnn_layers=3,\n", - " rnn_size=256,\n", - " use_gru=False,\n", - " share_rnn_weights=True)\n", - " self.feat_size = feat_size # 161 for linear\n", - " self.dict_size = dict_size\n", - "\n", - " self.conv = ConvStack(feat_size, num_conv_layers)\n", - " \n", - "# self.fc = nn.Linear(1312, dict_size + 1)\n", - "\n", - " i_size = self.conv.output_height # H after conv stack\n", - " self.rnn = RNNStack(\n", - " i_size=i_size,\n", - " h_size=rnn_size,\n", - " num_stacks=num_rnn_layers,\n", - " use_gru=use_gru,\n", - " share_rnn_weights=share_rnn_weights)\n", - " \n", - " self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n", - " \n", - " def infer(self, audio, audio_len):\n", - " # [B, D, T] -> [B, C=1, D, T]\n", - " audio = audio.unsqueeze(1)\n", - "\n", - " # convolution group\n", - " x, audio_len = self.conv(audio, audio_len)\n", - " print('conv out', x.shape)\n", - "\n", - " # convert data from convolution feature map to sequence of vectors\n", - " B, C, D, T = paddle.shape(x)\n", - " x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n", - " x = x.reshape([B, T, C * D]) #[B, T, C*D]\n", - " print('rnn input', x.shape)\n", - "\n", - " # remove padding part\n", - " x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n", - " print('rnn output', x.shape)\n", - "\n", - " logits = self.fc(x) #[B, T, V + 1]\n", - "\n", - " #ctcdecoder need probs, not log_probs\n", - " probs = F.softmax(logits)\n", - "\n", - " return logits, probs, audio_len\n", - "\n", - " def forward(self, audio, audio_len, text, text_len):\n", - " \"\"\"\n", - " audio: shape [B, D, T]\n", - " text: shape [B, T]\n", - " audio_len: shape [B]\n", - " text_len: shape [B]\n", - " \"\"\"\n", - " return self.infer(audio, audio_len)\n", - " \n", - "\n", - "feat_dim=161\n", - "\n", - "model = DeepSpeech2Test(\n", - " feat_size=feat_dim,\n", - " dict_size=batch_reader.dataset.vocab_size,\n", - " num_conv_layers=args.num_conv_layers,\n", - " num_rnn_layers=args.num_rnn_layers,\n", - " rnn_size=1024,\n", - " use_gru=args.use_gru,\n", - " share_rnn_weights=args.share_rnn_weights,\n", - " )\n", - "dp_model = model\n", - "#dp_model = paddle.DataParallel(model)\n", - "\n", - "loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "divided-incentive", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "discrete-conjunction", - "metadata": {}, - "outputs": [], - "source": [ - "audio, audio_len, text, text_len = None, None, None, None\n", - "\n", - "for idx, inputs in enumerate(batch_reader):\n", - " audio, audio_len, text, text_len = inputs\n", - "# print(idx)\n", - "# print('a', audio.shape, audio.place)\n", - "# print('t', text)\n", - "# print('al', audio_len)\n", - "# print('tl', text_len)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "protected-announcement", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [55, 56, 60, 62, 62])\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [55, 56, 60, 62, 62])\n", - "conv out [5, 32, 41, 62]\n", - "rnn input [5, 62, 1312]\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", - " return (isinstance(seq, collections.Sequence) and\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "rnn output [5, 62, 2048]\n", - "logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [55, 56, 60, 62, 62])\n", - "loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [2316.82153320])\n" - ] - } - ], - "source": [ - "outputs = dp_model(audio, audio_len, text, text_len)\n", - "logits, _, logits_len = outputs\n", - "print('logits len', logits_len)\n", - "loss = loss_fn.forward(logits, text, logits_len, text_len)\n", - "print('loss', loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "universal-myrtle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n", - "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n", - "param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n" - ] - } - ], - "source": [ - "for n, p in dp_model.named_parameters():\n", - " print(\n", - " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "referenced-double", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n", - " 4.757795 ]\n", - " [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n", - " 5.1787405 ]\n", - " [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n", - " 3.5026932 ]\n", - " ...\n", - " [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n", - " 2.359191 ]\n", - " [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n", - " 4.829253 ]\n", - " [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n", - " 3.7477999 ]]]\n", - "\n", - "\n", - " [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n", - " 1.0245132 ]\n", - " [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n", - " -0.41603735]\n", - " [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n", - " -0.96788657]\n", - " ...\n", - " [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n", - " 2.6364415 ]\n", - " [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n", - " 3.3839993 ]\n", - " [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n", - " 4.1078367 ]]]\n", - "\n", - "\n", - " [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n", - " -6.139402 ]\n", - " [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n", - " -6.158736 ]\n", - " [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n", - " -5.416684 ]\n", - " ...\n", - " [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n", - " -0.8772557 ]\n", - " [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n", - " -3.6650617 ]\n", - " [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n", - " -2.5240796 ]]]\n", - "\n", - "\n", - " ...\n", - "\n", - "\n", - " [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n", - " 2.101063 ]\n", - " [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n", - " 1.4128599 ]\n", - " [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n", - " 1.6222539 ]\n", - " ...\n", - " [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n", - " 2.7494807 ]\n", - " [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n", - " 2.2483966 ]\n", - " [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n", - " 1.7467536 ]]]\n", - "\n", - "\n", - " [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n", - " -2.6987422 ]\n", - " [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n", - " -3.6595795 ]\n", - " [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n", - " -3.481941 ]\n", - " ...\n", - " [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n", - " -1.2167063 ]\n", - " [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n", - " 0.0715822 ]\n", - " [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n", - " 1.4772662 ]]]\n", - "\n", - "\n", - " [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n", - " -3.5771027 ]\n", - " [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n", - " -2.0113711 ]\n", - " [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n", - " -4.3225455 ]\n", - " ...\n", - " [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n", - " 2.6247358 ]\n", - " [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n", - " -0.9618193 ]\n", - " [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n", - " -0.79297894]]]]\n", - "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n", - " 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n", - " 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n", - " -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n", - " 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n", - " 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n", - " -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n", - " 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n", - "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n", - " -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n", - " 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n", - " 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n", - " 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n", - " -5.628145 -1.0894046 ]\n", - "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n", - " -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n", - " -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n", - " 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n", - " 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n", - " -4.116946 -0.9909375 ]\n", - "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n", - " 6.71104014e-01 -1.95339048e+00]\n", - " [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n", - " -9.42245364e-01 -3.34277248e+00]\n", - " [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n", - " 1.36039746e+00 -2.94621849e+00]\n", - " ...\n", - " [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n", - " -3.23889256e-02 -2.30832791e+00]\n", - " [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n", - " -3.51897454e+00 -2.10093403e+00]\n", - " [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n", - " -7.38576293e-01 -9.44581270e-01]]\n", - "\n", - " [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n", - " -4.09437418e-01 -1.74107194e+00]\n", - " [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n", - " -3.03250861e+00 -2.37099743e+00]\n", - " [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n", - " -2.08650708e+00 -1.82109618e+00]\n", - " ...\n", - " [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n", - " -1.31862307e+00 -2.07174897e+00]\n", - " [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n", - " -9.57888126e-01 -3.82121086e-01]\n", - " [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n", - " 1.78124309e-02 -3.40287232e+00]]\n", - "\n", - " [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n", - " 1.65331578e+00 2.58466601e-02]\n", - " [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n", - " 1.94632411e+00 -1.06698799e+00]\n", - " [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n", - " 4.27859402e+00 3.66381669e+00]\n", - " ...\n", - " [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n", - " 4.30536079e+00 -2.49779320e+00]\n", - " [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n", - " 2.55930424e-01 4.50821817e-02]\n", - " [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n", - " 1.09565187e+00 -4.96661961e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n", - " 2.56748104e+00 1.67592216e+00]\n", - " [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n", - " 2.92086434e+00 2.08308399e-01]\n", - " [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n", - " 3.55370259e+00 4.75085378e-01]\n", - " ...\n", - " [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n", - " -1.66419971e+00 -1.70375630e-01]\n", - " [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n", - " 9.49141383e-02 8.88008535e-01]\n", - " [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n", - " 1.93200278e+00 8.19505215e-01]]\n", - "\n", - " [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n", - " 3.19132543e+00 3.17435169e+00]\n", - " [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n", - " 2.76059532e+00 4.83479440e-01]\n", - " [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n", - " 2.80596399e+00 1.52781498e+00]\n", - " ...\n", - " [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n", - " -2.26546407e-01 -1.18434143e+00]\n", - " [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n", - " -1.31319761e-01 -3.85830402e-01]\n", - " [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n", - " 1.68037796e+00 -6.50003672e-01]]\n", - "\n", - " [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n", - " 2.07197094e+00 4.48752761e-01]\n", - " [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n", - " 1.88659012e+00 1.48252523e+00]\n", - " [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n", - " 3.75945044e+00 -3.09408635e-01]\n", - " ...\n", - " [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n", - " -1.26373076e+00 -1.37826729e+00]\n", - " [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n", - " 1.50002241e-02 1.84633851e+00]\n", - " [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n", - " 1.53240371e+00 7.47349262e-02]]]\n", - "\n", - "\n", - " [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n", - " -1.08945215e+00 4.19084758e-01]\n", - " [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n", - " 8.21905077e-01 1.89097953e+00]\n", - " [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n", - " -2.78253704e-01 6.96343541e-01]\n", - " ...\n", - " [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n", - " -9.71581042e-02 1.71877682e+00]\n", - " [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n", - " -1.15296638e+00 4.35728967e-01]\n", - " [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... -3.07936192e-01\n", - " -5.39051890e-01 1.13107657e+00]]\n", - "\n", - " [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n", - " 2.95773447e-02 1.27177119e+00]\n", - " [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n", - " 3.85564029e-01 4.57195908e-01]\n", - " [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n", - " -6.58698231e-02 -3.61778140e-01]\n", - " ...\n", - " [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n", - " 1.21484184e+00 2.22764325e+00]\n", - " [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n", - " 2.39702511e+00 2.43942714e+00]\n", - " [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n", - " 2.50086927e+00 1.91134131e+00]]\n", - "\n", - " [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n", - " -3.79540777e+00 -1.09796596e+00]\n", - " [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n", - " -2.59143281e+00 7.74220943e-01]\n", - " [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n", - " -1.23293114e+00 1.58148706e+00]\n", - " ...\n", - " [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n", - " 2.04232907e+00 1.97716308e+00]\n", - " [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n", - " 2.27003932e+00 -1.11734614e-01]\n", - " [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n", - " 5.78273535e-01 -2.59924948e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n", - " -1.53534544e+00 -6.67162359e-01]\n", - " [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n", - " 2.68609226e-01 -1.35124445e+00]\n", - " [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n", - " -5.04359961e-01 -1.81241560e+00]\n", - " ...\n", - " [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n", - " 3.29260826e-01 5.83679557e-01]\n", - " [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n", - " 2.91730791e-01 1.33400351e-01]\n", - " [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n", - " 4.90157723e-01 2.40562439e+00]]\n", - "\n", - " [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n", - " 1.91490650e+00 1.76274943e+00]\n", - " [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n", - " 2.53548074e+00 2.44395185e+00]\n", - " [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n", - " 2.94634581e+00 2.21452284e+00]\n", - " ...\n", - " [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n", - " -2.00367713e+00 -7.09145784e-01]\n", - " [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n", - " -7.83945560e-01 -1.01922750e-01]\n", - " [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n", - " 1.60286129e-02 -7.26446986e-01]]\n", - "\n", - " [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n", - " 8.52760911e-01 1.27920091e-01]\n", - " [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n", - " 3.97556186e-01 1.69182217e+00]\n", - " [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n", - " -1.07720590e+00 -3.81854475e-01]\n", - " ...\n", - " [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n", - " -4.48192549e+00 -6.14600539e-01]\n", - " [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... -1.97504926e+00\n", - " -4.44984436e+00 -2.28429958e-01]\n", - " [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n", - " -1.29259872e+00 1.30360842e-01]]]\n", - "\n", - "\n", - " [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n", - " -6.68072224e-01 -1.13282454e+00]\n", - " [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n", - " -1.00205469e+00 -3.58956242e+00]\n", - " [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n", - " 2.25133196e-01 -1.02802217e+00]\n", - " ...\n", - " [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n", - " -1.63268471e+00 6.00247443e-01]\n", - " [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n", - " -1.47013855e+00 -1.45779204e+00]\n", - " [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n", - " -1.36285520e+00 -1.68160331e+00]]\n", - "\n", - " [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n", - " -2.32403350e+00 -4.25943947e+00]\n", - " [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n", - " -1.16508257e+00 -1.30353463e+00]\n", - " [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n", - " -1.91106403e+00 -1.07801402e+00]\n", - " ...\n", - " [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n", - " -5.27612507e-01 1.44604075e+00]\n", - " [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n", - " 4.67946082e-01 -7.24248111e-01]\n", - " [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n", - " -7.51657963e-01 -1.00530767e+00]]\n", - "\n", - " [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n", - " -2.35717297e-02 -1.71374834e+00]\n", - " [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n", - " -1.41204190e+00 -7.58325040e-01]\n", - " [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n", - " -1.07798588e+00 -1.48433352e+00]\n", - " ...\n", - " [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n", - " 1.24199319e+00 1.89949930e-01]\n", - " [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n", - " -5.78772545e-01 2.62665659e-01]\n", - " [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n", - " 2.79776216e-01 7.63269484e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n", - " 1.16390383e+00 -1.47593319e-01]\n", - " [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n", - " 1.89411759e+00 1.03220642e-01]\n", - " [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n", - " 2.07282662e+00 -2.56800652e-01]\n", - " ...\n", - " [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n", - " 1.80851030e+00 8.32634747e-01]\n", - " [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n", - " 1.30439472e+00 3.61029744e-01]\n", - " [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n", - " -9.16651785e-02 9.93353873e-02]]\n", - "\n", - " [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n", - " 3.94066262e+00 4.89832020e+00]\n", - " [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n", - " 3.22004294e+00 2.99872303e+00]\n", - " [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n", - " 2.98195982e+00 3.37315011e+00]\n", - " ...\n", - " [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 1.20492995e+00\n", - " 2.13971066e+00 1.94932449e+00]\n", - " [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n", - " 2.29117441e+00 2.52368808e+00]\n", - " [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n", - " 1.86903965e+00 2.27776885e+00]]\n", - "\n", - " [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n", - " 7.87163258e-01 1.20610237e+00]\n", - " [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n", - " 9.66498017e-01 1.18883348e+00]\n", - " [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n", - " 5.25399208e-01 -1.16850233e+00]\n", - " ...\n", - " [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n", - " -1.49128008e+00 -1.85064209e+00]\n", - " [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n", - " 9.67048705e-02 -4.37043965e-01]\n", - " [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n", - " -7.83308804e-01 -1.64871216e+00]]]\n", - "\n", - "\n", - " ...\n", - "\n", - "\n", - " [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n", - " -1.29671764e+00 -1.61429560e+00]\n", - " [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n", - " -1.26383066e+00 -2.27464601e-01]\n", - " [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n", - " -3.68550658e-01 -2.58849680e-01]\n", - " ...\n", - " [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n", - " 2.67662477e+00 2.31014514e+00]\n", - " [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n", - " 2.77418542e+00 1.22875953e+00]\n", - " [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n", - " 1.95602298e+00 7.64306366e-01]]\n", - "\n", - " [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n", - " 1.90896463e+00 8.02283168e-01]\n", - " [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n", - " 1.53455496e+00 1.94702053e+00]\n", - " [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n", - " 7.79394627e-01 2.02670670e+00]\n", - " ...\n", - " [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n", - " 9.61961091e-01 1.40342009e+00]\n", - " [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n", - " 9.43485856e-01 2.34856772e+00]\n", - " [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n", - " 1.38143206e+00 2.00714707e+00]]\n", - "\n", - " [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n", - " -2.52133775e+00 -8.87715697e-01]\n", - " [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n", - " -6.10453069e-01 2.06534386e+00]\n", - " [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n", - " -1.12928271e+00 4.87097502e-02]\n", - " ...\n", - " [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n", - " 1.06069386e+00 1.36404145e+00]\n", - " [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n", - " 1.48393130e+00 2.69196725e+00]\n", - " [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n", - " 3.67927134e-01 2.55976987e+00]]\n", - "\n", - " ...\n", - "\n", - " [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n", - " 2.57500267e+00 -1.38083458e-01]\n", - " [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n", - " 2.34770632e+00 6.48133337e-01]\n", - " [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 5.51510525e+00\n", - " 4.34810448e+00 1.31588542e+00]\n", - " ...\n", - " [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n", - " 1.06280947e+00 -5.93325794e-02]\n", - " [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n", - " 8.06131065e-02 8.08142900e-01]\n", - " [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n", - " 3.13675582e-01 2.13851762e+00]]\n", - "\n", - " [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n", - " 3.17874146e+00 1.47012353e+00]\n", - " [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n", - " 2.73698759e+00 2.17057824e+00]\n", - " [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n", - " 3.43153954e+00 1.21802723e+00]\n", - " ...\n", - " [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n", - " 1.64856291e+00 1.15853786e+00]\n", - " [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n", - " 2.24890351e+00 2.22156644e+00]\n", - " [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n", - " 3.71970820e+00 1.69852686e+00]]\n", - "\n", - " [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n", - " 1.48759770e+00 -1.72001183e+00]\n", - " [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n", - " 1.74327165e-01 -1.83029580e+00]\n", - " [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n", - " 1.05627227e+00 -8.02519619e-01]\n", - " ...\n", - " [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n", - " 4.66731453e+00 4.98701715e+00]\n", - " [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n", - " 3.32219076e+00 3.83035636e+00]\n", - " [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n", - " 2.13785577e+00 4.33214951e+00]]]\n", - "\n", - "\n", - " [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n", - " 2.72914767e-01 3.92112255e-01]\n", - " [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n", - " -1.44200325e-01 4.07531261e-02]\n", - " [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n", - " -2.28934169e+00 -7.38609374e-01]\n", - " ...\n", - " [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n", - " -1.23878837e+00 2.21006930e-01]\n", - " [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n", - " -6.22699618e-01 -1.73162484e+00]\n", - " [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n", - " -1.15777385e+00 -1.13671994e+00]]\n", - "\n", - " [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n", - " 8.81283879e-02 -1.37762165e+00]\n", - " [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n", - " -9.98012185e-01 -1.11348581e+00]\n", - " [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n", - " -3.65677416e-01 -2.25250268e+00]\n", - " ...\n", - " [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n", - " -2.52748466e+00 -2.16131711e+00]\n", - " [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n", - " -2.44736362e+00 -8.78590107e-01]\n", - " [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n", - " -2.45255780e+00 -1.82384634e+00]]\n", - "\n", - " [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n", - " 4.13422853e-01 1.12761199e+00]\n", - " [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n", - " -8.72656107e-01 1.52415514e-01]\n", - " [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 3.18863559e+00\n", - " 1.87347198e+00 9.48901057e-01]\n", - " ...\n", - " [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n", - " -6.83587790e-01 1.49154460e+00]\n", - " [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n", - " 1.54248989e+00 -1.46732473e+00]\n", - " [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n", - " 1.12405360e-01 -6.87821358e-02]]\n", - "\n", - " ...\n", - "\n", - " [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n", - " 4.03424358e+00 3.73335361e+00]\n", - " [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n", - " 2.29408574e+00 3.14105940e+00]\n", - " [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n", - " 3.91481900e+00 5.32456112e+00]\n", - " ...\n", - " [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n", - " 3.87909842e+00 3.33359623e+00]\n", - " [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n", - " 1.55827272e+00 2.50741863e+00]\n", - " [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n", - " 3.28739667e+00 3.02895570e+00]]\n", - "\n", - " [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n", - " 1.52944064e+00 2.08810711e+00]\n", - " [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n", - " 7.46028006e-01 1.74789345e+00]\n", - " [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n", - " 1.49640536e+00 2.16613841e+00]\n", - " ...\n", - " [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n", - " 3.87935114e+00 3.69385266e+00]\n", - " [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n", - " 5.36551809e+00 2.94915986e+00]\n", - " [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n", - " 3.93462896e+00 3.68200874e+00]]\n", - "\n", - " [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n", - " 1.60273385e+00 4.90113163e+00]\n", - " [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n", - " 1.40544355e+00 2.84833241e+00]\n", - " [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n", - " 6.87886119e-01 2.58609629e+00]\n", - " ...\n", - " [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n", - " 2.30406880e+00 1.29892707e+00]\n", - " [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n", - " 8.59923482e-01 1.38731599e+00]\n", - " [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n", - " 1.44120216e+00 1.79993320e+00]]]\n", - "\n", - "\n", - " [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n", - " 1.65043330e+00 3.45897645e-01]\n", - " [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n", - " 9.12241280e-01 1.27429694e-01]\n", - " [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n", - " 1.00041127e+00 5.87104559e-02]\n", - " ...\n", - " [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n", - " 2.35047388e+00 1.61532843e+00]\n", - " [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n", - " 2.32770801e+00 9.96662021e-01]\n", - " [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n", - " 2.05776072e+00 5.34575820e-01]]\n", - "\n", - " [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n", - " -9.85817850e-01 -2.05517030e+00]\n", - " [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... -1.42974198e-01\n", - " -1.79659498e+00 -1.58266973e+00]\n", - " [-2.51316994e-01 -1.28709340e+00 3.01498562e-01 ... -1.32253516e+00\n", - " -1.55507576e+00 -9.37123299e-01]\n", - " ...\n", - " [ 2.33016998e-01 2.92454743e+00 3.15420461e+00 ... 1.15574491e+00\n", - " 1.27850962e+00 1.35487700e+00]\n", - " [ 3.81013602e-01 1.44239831e+00 6.64825320e-01 ... -3.89374971e-01\n", - " 1.50716826e-01 1.33641326e+00]\n", - " [ 1.71373415e+00 1.67357373e+00 1.76596940e+00 ... 1.57941079e+00\n", - " 1.60940981e+00 1.78091609e+00]]\n", - "\n", - " [[-5.16522598e+00 -1.68099070e+00 -3.24440050e+00 ... -3.46229005e+00\n", - " -2.18273020e+00 -1.98621082e+00]\n", - " [-3.05743694e+00 9.15392339e-01 -1.93508530e+00 ... -1.82306373e+00\n", - " -2.12960863e+00 -3.45255351e+00]\n", - " [-4.32777822e-01 -1.00303245e+00 -1.61397791e+00 ... -2.08376765e+00\n", - " -3.72989595e-01 -1.36516929e+00]\n", - " ...\n", - " [-5.83641946e-01 4.14125490e+00 1.58227599e+00 ... 2.03144050e+00\n", - " 2.13982654e+00 -1.81909311e+00]\n", - " [-1.74230576e+00 2.39347410e+00 2.44080925e+00 ... 5.43732524e-01\n", - " 2.07899213e+00 -3.71748984e-01]\n", - " [ 3.80016506e-01 7.84988403e-01 1.20596504e+00 ... -2.32057095e+00\n", - " -2.81265080e-01 -3.69353056e+00]]\n", - "\n", - " ...\n", - "\n", - " [[-3.48024845e+00 -2.60937548e+00 -3.84952760e+00 ... 6.68736577e-01\n", - " -1.75104141e-02 -3.54720926e+00]\n", - " [-2.59637117e+00 -5.18190145e+00 -2.33887696e+00 ... 9.13373232e-02\n", - " -3.58282638e+00 -2.40778995e+00]\n", - " [-2.50912881e+00 -1.22113395e+00 -2.34372020e+00 ... 1.40071487e+00\n", - " -1.67449510e+00 -1.14655948e+00]\n", - " ...\n", - " [-5.75253534e+00 -6.67348385e+00 -5.05184650e+00 ... -2.73145151e+00\n", - " -1.48933101e+00 -1.36807609e+00]\n", - " [-3.29049587e+00 -3.73956156e+00 -2.85064268e+00 ... -3.92481357e-01\n", - " -8.00529659e-01 -8.39800835e-01]\n", - " [-4.30351114e+00 -4.21471930e+00 -2.41703367e+00 ... -1.27081513e+00\n", - " 1.67839837e+00 8.47821474e-01]]\n", - "\n", - " [[-5.27856112e-01 -1.09752083e+00 3.39107156e-01 ... 2.00062895e+00\n", - " 8.83528054e-01 2.57416844e-01]\n", - " [-1.58655810e+00 -3.36268663e-01 1.16161990e+00 ... 1.54868484e+00\n", - " 2.38878536e+00 1.84097290e+00]\n", - " [ 5.96052647e-01 2.15484858e-01 1.85280466e+00 ... 2.74587560e+00\n", - " 1.61432290e+00 1.13214278e+00]\n", - " ...\n", - " [-4.57659864e+00 -5.42679739e+00 -4.35204458e+00 ... -1.82452416e+00\n", - " -2.18670201e+00 -3.91811800e+00]\n", - " [-1.32477629e+00 -4.19110394e+00 -3.41308069e+00 ... 1.39622003e-01\n", - " -1.59393203e+00 -9.08105671e-01]\n", - " [-3.60161018e+00 -4.05932713e+00 -2.23674798e+00 ... 9.09647286e-01\n", - " 9.73127842e-01 1.19991803e+00]]\n", - "\n", - " [[ 2.04062796e+00 7.95603275e-01 -1.28833270e+00 ... 4.64749050e+00\n", - " 2.25974560e+00 1.02396965e+00]\n", - " [ 1.68882537e+00 2.63353348e+00 2.53597498e-02 ... 4.69063854e+00\n", - " -4.19382691e-01 2.91669458e-01]\n", - " [ 7.71395087e-01 1.20833695e+00 -2.58601785e-01 ... 1.21794045e+00\n", - " -1.51922226e-01 7.44265199e-01]\n", - " ...\n", - " [-6.66095781e+00 -4.81577682e+00 -5.39921665e+00 ... -2.20548606e+00\n", - " 5.72486281e-01 -4.35207397e-01]\n", - " [-7.51608658e+00 -6.67776871e+00 -3.73199415e+00 ... -1.70327055e+00\n", - " 1.01334639e-02 -3.20627165e+00]\n", - " [-5.73050356e+00 -2.74379373e+00 -3.70248461e+00 ... -1.09794116e+00\n", - " -1.73590891e-02 -1.80156028e+00]]]]\n", - "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n", - " 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n", - " -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n", - " -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n", - " -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n", - " -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n", - " 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n", - " -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n", - "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n", - " 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n", - " -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n", - " -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n", - " -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n", - " -0.24313992 0.31392363]\n", - "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n", - " 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n", - " -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n", - " -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n", - " -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n", - " -0.40964937 -1.4454535 ]\n", - "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n", - " -0.0658682 ]\n", - " [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n", - " -0.05464617]\n", - " [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n", - " -0.40110156]\n", - " ...\n", - " [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n", - " 0.37071258]\n", - " [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n", - " 0.32735693]\n", - " [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n", - " 0.0752247 ]]\n", - "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n", - " 0.3020359 ]\n", - "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n", - " 0.16263628]\n", - "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n", - " -0.25661892]\n", - " [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n", - " -0.17545304]\n", - " [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n", - " 0.05315325]\n", - " ...\n", - " [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n", - " -0.01544908]\n", - " [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n", - " 0.14791614]\n", - " [ 0.19052255 0.03642382 -0.14313167 ... 0.2611448 0.20763844\n", - " 0.26846847]]\n", - "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n", - " 0.16263627]\n", - "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n", - " -0.30598596]\n", - " [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n", - " -0.5335691 ]\n", - " [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n", - " 0.25522667]\n", - " ...\n", - " [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n", - " -0.07482046]\n", - " [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n", - " -0.2011079 ]\n", - " [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n", - " -0.17761081]]\n", - "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n", - " -0.10503208]\n", - "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n", - " 0.11344345]\n", - "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n", - " -0.29202533]\n", - " [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n", - " 0.05170784]\n", - " [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n", - " -0.08760723]\n", - " ...\n", - " [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n", - " -0.14726155]\n", - " [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n", - " 0.02569831]\n", - " [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n", - " 0.41056633]]\n", - "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n", - " 0.11344352]\n", - "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n", - " 0.05862035]\n", - " [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n", - " -0.3197178 ]\n", - " [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n", - " -0.44295847]\n", - " ...\n", - " [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n", - " -0.15528557]\n", - " [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n", - " -0.61484563]\n", - " [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n", - " 0.17856634]]\n", - "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n", - " 1.6333911 ]\n", - "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... -1.4478713 1.455745\n", - " 2.069446 ]\n", - "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n", - " -0.00845779]\n", - " [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n", - " -0.855925 ]\n", - " [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n", - " 0.11303265]\n", - " ...\n", - " [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n", - " -0.74844164]\n", - " [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n", - " 0.71526295]\n", - " [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n", - " 1.2852117 ]]\n", - "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n", - " 2.0694466 ]\n", - "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n", - " 7.4421698e-03 -2.3925617e+01]\n", - " [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n", - " 1.8560058e-02 -5.0687141e+01]\n", - " [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n", - " 2.1536341e-02 -6.1036335e+01]\n", - " ...\n", - " [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n", - " 1.3860047e-02 -5.2543671e+01]\n", - " [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n", - " 1.7470375e-02 -4.3619675e+01]\n", - " [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n", - " 2.2168703e-02 -6.7901680e+01]]\n", - "param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n", - " 4.3404289e-02 -1.3512518e+02]\n" - ] - } - ], - "source": [ - "loss.backward(retain_graph=False)\n", - "for n, p in dp_model.named_parameters():\n", - " print(\n", - " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "selected-crazy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1.]\n" - ] - } - ], - "source": [ - "print(loss.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bottom-engineer", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "stuffed-yeast", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/.notebook/u2_confermer_model_wenet.ipynb b/.notebook/u2_confermer_model_wenet.ipynb deleted file mode 100644 index a425e16cb..000000000 --- a/.notebook/u2_confermer_model_wenet.ipynb +++ /dev/null @@ -1,4608 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-grade", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "broke-broad", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "register user softmax to paddle, remove this when fixed!\n", - "register user log_softmax to paddle, remove this when fixed!\n", - "register user sigmoid to paddle, remove this when fixed!\n", - "register user log_sigmoid to paddle, remove this when fixed!\n", - "register user relu to paddle, remove this when fixed!\n", - "override cat of paddle if exists or register, remove this when fixed!\n", - "override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle if exists or register, remove this when fixed!\n", - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "register user view to paddle.Tensor, remove this when fixed!\n", - "register user view_as to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "register user fill_ to paddle.Tensor, remove this when fixed!\n", - "register user repeat to paddle.Tensor, remove this when fixed!\n", - "register user softmax to paddle.Tensor, remove this when fixed!\n", - "register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "register user relu to paddle.Tensor, remove this when fixed!\n", - "register user type_as to paddle.Tensor, remove this when fixed!\n", - "register user to to paddle.Tensor, remove this when fixed!\n", - "register user float to paddle.Tensor, remove this when fixed!\n", - "register user tolist to paddle.Tensor, remove this when fixed!\n", - "register user glu to paddle.nn.functional, remove this when fixed!\n", - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "register user Module to paddle.nn, remove this when fixed!\n", - "register user ModuleList to paddle.nn, remove this when fixed!\n", - "register user GLU to paddle.nn, remove this when fixed!\n", - "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "register user export to paddle.jit, remove this when fixed!\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import paddle\n", - "from yacs.config import CfgNode as CN\n", - "\n", - "from deepspeech.models.u2 import U2Model\n", - "from deepspeech.utils.layer_tools import print_params\n", - "from deepspeech.utils.layer_tools import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "permanent-summary", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[INFO 2021/04/20 03:32:21 u2.py:834] U2 Encoder type: conformer\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "encoder.embed.conv.0.bias | [256] | 256 | True\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "encoder.embed.conv.2.bias | [256] | 256 | True\n", - "encoder.embed.out.0.weight | [4864, 256] | 1245184 | True\n", - "encoder.embed.out.0.bias | [256] | 256 | True\n", - "encoder.after_norm.weight | [256] | 256 | True\n", - "encoder.after_norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.0.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.1.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.2.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.3.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.4.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.5.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.6.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.7.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.8.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.9.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.10.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.11.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", - "decoder.after_norm.weight | [256] | 256 | True\n", - "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", - "decoder.output_layer.bias | [4233] | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", - "ctc.ctc_lo.bias | [4233] | 4233 | True\n", - "Total parameters: 687.0, 49355282.0 elements.\n" - ] - } - ], - "source": [ - "conf_str='examples/aishell/s1/conf/conformer.yaml'\n", - "cfg = CN().load_cfg(open(conf_str))\n", - "cfg.model.input_dim = 80\n", - "cfg.model.output_dim = 4233\n", - "cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n", - "cfg.model.cmvn_file_type = 'json'\n", - "cfg.freeze()\n", - "\n", - "model = U2Model(cfg.model)\n", - "print_params(model)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "sapphire-agent", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.global_cmvn.mean | [80] | 80\n", - "encoder.global_cmvn.istd | [80] | 80\n", - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304\n", - "encoder.embed.conv.0.bias | [256] | 256\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824\n", - "encoder.embed.conv.2.bias | [256] | 256\n", - "encoder.embed.out.0.weight | [4864, 256] | 1245184\n", - "encoder.embed.out.0.bias | [256] | 256\n", - "encoder.after_norm.weight | [256] | 256\n", - "encoder.after_norm.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.0.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.0.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.0.norm_ff.weight | [256] | 256\n", - "encoder.encoders.0.norm_ff.bias | [256] | 256\n", - "encoder.encoders.0.norm_mha.weight | [256] | 256\n", - "encoder.encoders.0.norm_mha.bias | [256] | 256\n", - "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.0.norm_conv.weight | [256] | 256\n", - "encoder.encoders.0.norm_conv.bias | [256] | 256\n", - "encoder.encoders.0.norm_final.weight | [256] | 256\n", - "encoder.encoders.0.norm_final.bias | [256] | 256\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.1.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.1.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.1.norm_ff.weight | [256] | 256\n", - "encoder.encoders.1.norm_ff.bias | [256] | 256\n", - "encoder.encoders.1.norm_mha.weight | [256] | 256\n", - "encoder.encoders.1.norm_mha.bias | [256] | 256\n", - "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.1.norm_conv.weight | [256] | 256\n", - "encoder.encoders.1.norm_conv.bias | [256] | 256\n", - "encoder.encoders.1.norm_final.weight | [256] | 256\n", - "encoder.encoders.1.norm_final.bias | [256] | 256\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.2.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.2.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.2.norm_ff.weight | [256] | 256\n", - "encoder.encoders.2.norm_ff.bias | [256] | 256\n", - "encoder.encoders.2.norm_mha.weight | [256] | 256\n", - "encoder.encoders.2.norm_mha.bias | [256] | 256\n", - "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.2.norm_conv.weight | [256] | 256\n", - "encoder.encoders.2.norm_conv.bias | [256] | 256\n", - "encoder.encoders.2.norm_final.weight | [256] | 256\n", - "encoder.encoders.2.norm_final.bias | [256] | 256\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.3.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.3.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.3.norm_ff.weight | [256] | 256\n", - "encoder.encoders.3.norm_ff.bias | [256] | 256\n", - "encoder.encoders.3.norm_mha.weight | [256] | 256\n", - "encoder.encoders.3.norm_mha.bias | [256] | 256\n", - "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.3.norm_conv.weight | [256] | 256\n", - "encoder.encoders.3.norm_conv.bias | [256] | 256\n", - "encoder.encoders.3.norm_final.weight | [256] | 256\n", - "encoder.encoders.3.norm_final.bias | [256] | 256\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.4.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.4.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.4.norm_ff.weight | [256] | 256\n", - "encoder.encoders.4.norm_ff.bias | [256] | 256\n", - "encoder.encoders.4.norm_mha.weight | [256] | 256\n", - "encoder.encoders.4.norm_mha.bias | [256] | 256\n", - "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.4.norm_conv.weight | [256] | 256\n", - "encoder.encoders.4.norm_conv.bias | [256] | 256\n", - "encoder.encoders.4.norm_final.weight | [256] | 256\n", - "encoder.encoders.4.norm_final.bias | [256] | 256\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.5.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.5.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.5.norm_ff.weight | [256] | 256\n", - "encoder.encoders.5.norm_ff.bias | [256] | 256\n", - "encoder.encoders.5.norm_mha.weight | [256] | 256\n", - "encoder.encoders.5.norm_mha.bias | [256] | 256\n", - "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.5.norm_conv.weight | [256] | 256\n", - "encoder.encoders.5.norm_conv.bias | [256] | 256\n", - "encoder.encoders.5.norm_final.weight | [256] | 256\n", - "encoder.encoders.5.norm_final.bias | [256] | 256\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.6.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.6.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.6.norm_ff.weight | [256] | 256\n", - "encoder.encoders.6.norm_ff.bias | [256] | 256\n", - "encoder.encoders.6.norm_mha.weight | [256] | 256\n", - "encoder.encoders.6.norm_mha.bias | [256] | 256\n", - "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.6.norm_conv.weight | [256] | 256\n", - "encoder.encoders.6.norm_conv.bias | [256] | 256\n", - "encoder.encoders.6.norm_final.weight | [256] | 256\n", - "encoder.encoders.6.norm_final.bias | [256] | 256\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.7.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.7.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.7.norm_ff.weight | [256] | 256\n", - "encoder.encoders.7.norm_ff.bias | [256] | 256\n", - "encoder.encoders.7.norm_mha.weight | [256] | 256\n", - "encoder.encoders.7.norm_mha.bias | [256] | 256\n", - "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.7.norm_conv.weight | [256] | 256\n", - "encoder.encoders.7.norm_conv.bias | [256] | 256\n", - "encoder.encoders.7.norm_final.weight | [256] | 256\n", - "encoder.encoders.7.norm_final.bias | [256] | 256\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.8.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.8.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.8.norm_ff.weight | [256] | 256\n", - "encoder.encoders.8.norm_ff.bias | [256] | 256\n", - "encoder.encoders.8.norm_mha.weight | [256] | 256\n", - "encoder.encoders.8.norm_mha.bias | [256] | 256\n", - "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.8.norm_conv.weight | [256] | 256\n", - "encoder.encoders.8.norm_conv.bias | [256] | 256\n", - "encoder.encoders.8.norm_final.weight | [256] | 256\n", - "encoder.encoders.8.norm_final.bias | [256] | 256\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.9.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.9.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.9.norm_ff.weight | [256] | 256\n", - "encoder.encoders.9.norm_ff.bias | [256] | 256\n", - "encoder.encoders.9.norm_mha.weight | [256] | 256\n", - "encoder.encoders.9.norm_mha.bias | [256] | 256\n", - "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.9.norm_conv.weight | [256] | 256\n", - "encoder.encoders.9.norm_conv.bias | [256] | 256\n", - "encoder.encoders.9.norm_final.weight | [256] | 256\n", - "encoder.encoders.9.norm_final.bias | [256] | 256\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.10.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.10.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.10.norm_ff.weight | [256] | 256\n", - "encoder.encoders.10.norm_ff.bias | [256] | 256\n", - "encoder.encoders.10.norm_mha.weight | [256] | 256\n", - "encoder.encoders.10.norm_mha.bias | [256] | 256\n", - "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.10.norm_conv.weight | [256] | 256\n", - "encoder.encoders.10.norm_conv.bias | [256] | 256\n", - "encoder.encoders.10.norm_final.weight | [256] | 256\n", - "encoder.encoders.10.norm_final.bias | [256] | 256\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.11.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.11.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.11.norm_ff.weight | [256] | 256\n", - "encoder.encoders.11.norm_ff.bias | [256] | 256\n", - "encoder.encoders.11.norm_mha.weight | [256] | 256\n", - "encoder.encoders.11.norm_mha.bias | [256] | 256\n", - "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.11.norm_conv.weight | [256] | 256\n", - "encoder.encoders.11.norm_conv.bias | [256] | 256\n", - "encoder.encoders.11.norm_final.weight | [256] | 256\n", - "encoder.encoders.11.norm_final.bias | [256] | 256\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256\n", - "decoder.embed.0.weight | [4233, 256] | 1083648\n", - "decoder.after_norm.weight | [256] | 256\n", - "decoder.after_norm.bias | [256] | 256\n", - "decoder.output_layer.weight | [256, 4233] | 1083648\n", - "decoder.output_layer.bias | [4233] | 4233\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.0.norm1.weight | [256] | 256\n", - "decoder.decoders.0.norm1.bias | [256] | 256\n", - "decoder.decoders.0.norm2.weight | [256] | 256\n", - "decoder.decoders.0.norm2.bias | [256] | 256\n", - "decoder.decoders.0.norm3.weight | [256] | 256\n", - "decoder.decoders.0.norm3.bias | [256] | 256\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.1.norm1.weight | [256] | 256\n", - "decoder.decoders.1.norm1.bias | [256] | 256\n", - "decoder.decoders.1.norm2.weight | [256] | 256\n", - "decoder.decoders.1.norm2.bias | [256] | 256\n", - "decoder.decoders.1.norm3.weight | [256] | 256\n", - "decoder.decoders.1.norm3.bias | [256] | 256\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.2.norm1.weight | [256] | 256\n", - "decoder.decoders.2.norm1.bias | [256] | 256\n", - "decoder.decoders.2.norm2.weight | [256] | 256\n", - "decoder.decoders.2.norm2.bias | [256] | 256\n", - "decoder.decoders.2.norm3.weight | [256] | 256\n", - "decoder.decoders.2.norm3.bias | [256] | 256\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.3.norm1.weight | [256] | 256\n", - "decoder.decoders.3.norm1.bias | [256] | 256\n", - "decoder.decoders.3.norm2.weight | [256] | 256\n", - "decoder.decoders.3.norm2.bias | [256] | 256\n", - "decoder.decoders.3.norm3.weight | [256] | 256\n", - "decoder.decoders.3.norm3.bias | [256] | 256\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.4.norm1.weight | [256] | 256\n", - "decoder.decoders.4.norm1.bias | [256] | 256\n", - "decoder.decoders.4.norm2.weight | [256] | 256\n", - "decoder.decoders.4.norm2.bias | [256] | 256\n", - "decoder.decoders.4.norm3.weight | [256] | 256\n", - "decoder.decoders.4.norm3.bias | [256] | 256\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.5.norm1.weight | [256] | 256\n", - "decoder.decoders.5.norm1.bias | [256] | 256\n", - "decoder.decoders.5.norm2.weight | [256] | 256\n", - "decoder.decoders.5.norm2.bias | [256] | 256\n", - "decoder.decoders.5.norm3.weight | [256] | 256\n", - "decoder.decoders.5.norm3.bias | [256] | 256\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648\n", - "ctc.ctc_lo.bias | [4233] | 4233\n", - "Total parameters: 689, 49355442 elements.\n" - ] - } - ], - "source": [ - "summary(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ruled-invitation", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "U2Model(\n", - " (encoder): ConformerEncoder(\n", - " (global_cmvn): GlobalCMVN()\n", - " (embed): Conv2dSubsampling4(\n", - " (pos_enc): RelPositionalEncoding(\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " )\n", - " (conv): Sequential(\n", - " (0): Conv2D(1, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", - " (1): ReLU()\n", - " (2): Conv2D(256, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", - " (3): ReLU()\n", - " )\n", - " (out): Sequential(\n", - " (0): Linear(in_features=4864, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (encoders): LayerList(\n", - " (0): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (1): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (2): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (3): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (4): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (5): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (6): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (7): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (8): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (9): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (10): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (11): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " )\n", - " (decoder): TransformerDecoder(\n", - " (embed): Sequential(\n", - " (0): Embedding(4233, 256, sparse=False)\n", - " (1): PositionalEncoding(\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (output_layer): Linear(in_features=256, out_features=4233, dtype=float32)\n", - " (decoders): LayerList(\n", - " (0): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " )\n", - " (ctc): CTCDecoder(\n", - " (ctc_lo): Linear(in_features=256, out_features=4233, dtype=float32)\n", - " (criterion): CTCLoss(\n", - " (loss): CTCLoss()\n", - " )\n", - " )\n", - " (criterion_att): LabelSmoothingLoss(\n", - " (criterion): KLDivLoss()\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fossil-means", - "metadata": {}, - "outputs": [], - "source": [ - "# load feat" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "fleet-despite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "compute_cmvn_loader_test.ipynb encoder.npz\r\n", - "dataloader.ipynb hack_api_test.ipynb\r\n", - "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", - "data.npz layer_norm_test.ipynb\r\n", - "decoder.npz Linear_test.ipynb\r\n", - "enc_0_ff_out.npz mask_and_masked_fill_test.ipynb\r\n", - "enc_0_norm_ff.npz model.npz\r\n", - "enc_0.npz position_embeding_check.ipynb\r\n", - "enc_0_selattn_out.npz python_test.ipynb\r\n", - "enc_2.npz train_test.ipynb\r\n", - "enc_all.npz u2_model.ipynb\r\n", - "enc_embed.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "abroad-oracle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['BAC009S0739W0246' 'BAC009S0727W0424' 'BAC009S0753W0412'\n", - " 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n", - " 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n", - " 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n", - " 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n", - " 'BAC009S0727W0418']\n", - "(16, 207, 80)\n", - "[[[ 8.994624 9.538309 9.191589 ... 10.507416 9.563305 8.256403 ]\n", - " [ 9.798841 10.405224 9.26511 ... 10.251211 9.543982 8.873768 ]\n", - " [10.6890745 10.395469 8.053548 ... 9.906749 10.064903 8.050915 ]\n", - " ...\n", - " [ 9.217986 9.65069 8.505259 ... 9.687183 8.742463 7.9865475]\n", - " [10.129122 9.935194 9.37982 ... 9.563894 9.825992 8.979543 ]\n", - " [ 9.095531 7.1338377 9.468001 ... 9.472748 9.021235 7.447914 ]]\n", - "\n", - " [[11.430976 10.671858 6.0841026 ... 9.382682 8.729745 7.5315614]\n", - " [ 9.731717 7.8104815 7.5714607 ... 10.043035 9.243595 7.3540792]\n", - " [10.65017 10.600604 8.467784 ... 9.281448 9.186885 8.070343 ]\n", - " ...\n", - " [ 9.096987 9.2637 8.075275 ... 8.431845 8.370505 8.002926 ]\n", - " [10.461651 10.147784 6.7693496 ... 9.779426 9.577453 8.080652 ]\n", - " [ 7.794432 5.621059 7.9750648 ... 9.997245 9.849678 8.031287 ]]\n", - "\n", - " [[ 7.3455667 7.896357 7.5795946 ... 11.631024 10.451254 9.123633 ]\n", - " [ 8.628678 8.4630575 7.499242 ... 12.415986 10.975749 8.9425745]\n", - " [ 9.831394 10.2812805 8.97241 ... 12.1386795 10.40175 9.005517 ]\n", - " ...\n", - " [ 7.089641 7.405548 6.8142557 ... 9.325196 9.273162 8.353427 ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " ...\n", - "\n", - " [[10.933237 10.464394 7.7202725 ... 10.348816 9.302338 7.1553144]\n", - " [10.449866 9.907033 9.029272 ... 9.952465 9.414051 7.559279 ]\n", - " [10.487655 9.81259 9.895244 ... 9.58662 9.341254 7.7849016]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " [[ 9.944384 9.585867 8.220328 ... 11.588647 11.045029 8.817075 ]\n", - " [ 7.678356 8.322397 7.533047 ... 11.055085 10.535685 9.27465 ]\n", - " [ 8.626197 9.675917 9.841045 ... 11.378827 10.922112 8.991444 ]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " [[ 8.107938 7.759043 6.710301 ... 12.650573 11.466156 11.061517 ]\n", - " [11.380332 11.222007 8.658889 ... 12.810616 12.222216 11.689288 ]\n", - " [10.677676 9.920579 8.046089 ... 13.572894 12.5624075 11.155033 ]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]]\n", - "[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n", - "[[2995 3116 1209 565 -1 -1]\n", - " [ 236 1176 331 66 3925 4077]\n", - " [2693 524 234 1145 366 -1]\n", - " [3875 4211 3062 700 -1 -1]\n", - " [ 272 987 1134 494 2959 -1]\n", - " [1936 3715 120 2553 2695 2710]\n", - " [ 25 1149 3930 -1 -1 -1]\n", - " [1753 1778 1237 482 3925 110]\n", - " [3703 2 565 3827 -1 -1]\n", - " [1150 2734 10 2478 3490 -1]\n", - " [ 426 811 95 489 144 -1]\n", - " [2313 2006 489 975 -1 -1]\n", - " [3702 3414 205 1488 2966 1347]\n", - " [ 70 1741 702 1666 -1 -1]\n", - " [ 703 1778 1030 849 -1 -1]\n", - " [ 814 1674 115 3827 -1 -1]]\n", - "[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/data.npz', allow_pickle=True)\n", - "keys=data['keys']\n", - "feat=data['feat']\n", - "feat_len=data['feat_len']\n", - "text=data['text']\n", - "text_len=data['text_len']\n", - "print(keys)\n", - "print(feat.shape)\n", - "print(feat)\n", - "print(feat_len)\n", - "print(text)\n", - "print(text_len)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "false-instrument", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "arctic-proxy", - "metadata": {}, - "outputs": [], - "source": [ - "# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", - "# torch.Size([16, 207, 80])\n", - "# tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", - "# [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", - "# [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", - "# ...,\n", - "# [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", - "# [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", - "# [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", - "\n", - "# [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", - "# [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", - "# [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", - "# ...,\n", - "# [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", - "# [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", - "# [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", - "\n", - "# [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", - "# [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", - "# [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", - "# ...,\n", - "# [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# ...,\n", - "\n", - "# [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", - "# [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", - "# [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", - "# [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", - "# [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", - "# [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", - "# [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", - "# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", - "# 166, 163], dtype=torch.int32)\n", - "# tensor([[2995, 3116, 1209, 565, -1, -1],\n", - "# [ 236, 1176, 331, 66, 3925, 4077],\n", - "# [2693, 524, 234, 1145, 366, -1],\n", - "# [3875, 4211, 3062, 700, -1, -1],\n", - "# [ 272, 987, 1134, 494, 2959, -1],\n", - "# [1936, 3715, 120, 2553, 2695, 2710],\n", - "# [ 25, 1149, 3930, -1, -1, -1],\n", - "# [1753, 1778, 1237, 482, 3925, 110],\n", - "# [3703, 2, 565, 3827, -1, -1],\n", - "# [1150, 2734, 10, 2478, 3490, -1],\n", - "# [ 426, 811, 95, 489, 144, -1],\n", - "# [2313, 2006, 489, 975, -1, -1],\n", - "# [3702, 3414, 205, 1488, 2966, 1347],\n", - "# [ 70, 1741, 702, 1666, -1, -1],\n", - "# [ 703, 1778, 1030, 849, -1, -1],\n", - "# [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", - "# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "seasonal-switch", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "defined-brooks", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "compute_cmvn_loader_test.ipynb\t encoder.npz\r\n", - "dataloader.ipynb\t\t hack_api_test.ipynb\r\n", - "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", - "data.npz\t\t\t layer_norm_test.ipynb\r\n", - "decoder.npz\t\t\t Linear_test.ipynb\r\n", - "enc_0_ff_out.npz\t\t mask_and_masked_fill_test.ipynb\r\n", - "enc_0_norm_ff.npz\t\t model.npz\r\n", - "enc_0.npz\t\t\t position_embeding_check.ipynb\r\n", - "enc_0_selattn_out.npz\t\t python_test.ipynb\r\n", - "enc_2.npz\t\t\t train_test.ipynb\r\n", - "enc_all.npz\t\t\t u2_model.ipynb\r\n", - "enc_embed.npz\r\n" - ] - } - ], - "source": [ - "# load model param\n", - "!ls .notebook\n", - "data = np.load('.notebook/model.npz', allow_pickle=True)\n", - "state_dict = data['state'].item()\n", - "\n", - "for key, _ in model.state_dict().items():\n", - " if key not in state_dict:\n", - " print(f\"{key} not find.\")\n", - "\n", - "model.set_state_dict(state_dict)\n", - "\n", - "now_state_dict = model.state_dict()\n", - "for key, value in now_state_dict.items():\n", - " if not np.allclose(value.numpy(), state_dict[key]):\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "exempt-viewer", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "confident-piano", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:687: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " elif dtype == np.bool:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [142.48880005]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [377.33258057])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:238: UserWarning: The dtype of left and right variables are not the same, left dtype is VarType.FP32, but right dtype is VarType.INT32, the right dtype will convert to VarType.FP32\n", - " format(lhs_dtype, rhs_dtype, lhs_dtype))\n" - ] - } - ], - "source": [ - "# compute loss\n", - "import paddle\n", - "feat=paddle.to_tensor(feat)\n", - "feat_len=paddle.to_tensor(feat_len, dtype='int64')\n", - "text=paddle.to_tensor(text, dtype='int64')\n", - "text_len=paddle.to_tensor(text_len, dtype='int64')\n", - "\n", - "model.eval()\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss, attention_loss, ctc_loss )" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "better-senator", - "metadata": {}, - "outputs": [], - "source": [ - "# tensor(142.4888, device='cuda:0', grad_fn=) \n", - "# tensor(41.8415, device='cuda:0', grad_fn=) \n", - "# tensor(377.3326, device='cuda:0', grad_fn=)\n", - "# 142.4888 41.84146 377.33258" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "related-banking", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "olympic-problem", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[16, 51, 256]\n", - "[16, 1, 51]\n", - "Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[-0.70194179, 0.56254166, 0.68803459, ..., 1.12373221, 0.78039235, 1.13693869],\n", - " [-0.77877808, 0.39126658, 0.71887815, ..., 1.25188220, 0.88616788, 1.31734526],\n", - " [-0.95908946, 0.63460249, 0.87671334, ..., 0.98183727, 0.74401081, 1.29032660],\n", - " ...,\n", - " [-1.07322502, 0.67236906, 0.92303109, ..., 0.90754563, 0.81767166, 1.32396567],\n", - " [-1.16541159, 0.68199694, 0.69394493, ..., 1.22383487, 0.80282891, 1.45065081],\n", - " [-1.27320945, 0.71458030, 0.75819558, ..., 0.94154912, 0.87748396, 1.26230514]])\n" - ] - } - ], - "source": [ - "# ecnoder\n", - "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)\n", - "print(encoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "shaped-alaska", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "deepspeech examples README_cn.md\tsetup.sh tools\r\n", - "docs\t LICENSE README.md\t\ttests\t utils\r\n", - "env.sh\t log requirements.txt\tthird_party\r\n" - ] - } - ], - "source": [ - "!ls\n", - "data = np.load('.notebook/encoder.npz', allow_pickle=True)\n", - "torch_mask = data['mask']\n", - "torch_encoder_out = data['out']" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "federal-rover", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n" - ] - } - ], - "source": [ - "print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "regulated-interstate", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n", - " 1.2623053 ]]\n", - "----\n", - "[[-0.7019418 0.56254166 0.6880346 ... 1.1237322 0.78039235\n", - " 1.1369387 ]\n", - " [-0.7787781 0.39126658 0.71887815 ... 1.2518822 0.8861679\n", - " 1.3173453 ]\n", - " [-0.95908946 0.6346025 0.87671334 ... 0.9818373 0.7440108\n", - " 1.2903266 ]\n", - " ...\n", - " [-1.073225 0.67236906 0.9230311 ... 0.9075456 0.81767166\n", - " 1.3239657 ]\n", - " [-1.1654116 0.68199694 0.69394493 ... 1.2238349 0.8028289\n", - " 1.4506508 ]\n", - " [-1.2732095 0.7145803 0.7581956 ... 0.9415491 0.87748396\n", - " 1.2623051 ]]\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", - "print(torch_encoder_out[0])\n", - "print(\"----\")\n", - "print(encoder_out.numpy()[0])\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5, rtol=1e-6))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6, rtol=1e-6))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "proof-scheduling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [377.33258057])\n", - "[1.]\n", - "[[ 3.16902876e+00 -1.51763987e-02 4.91095744e-02 ... -2.47971853e-03\n", - " -5.93360700e-03 -7.26609165e-03]\n", - " [-1.74184477e+00 7.75874173e-03 -4.49434854e-02 ... 9.92412097e-04\n", - " 2.46337592e-03 2.31892057e-03]\n", - " [-2.33343339e+00 1.30475955e-02 -2.66557075e-02 ... 2.27532350e-03\n", - " 5.76924905e-03 7.48788286e-03]\n", - " ...\n", - " [-4.30358458e+00 2.46054661e-02 -9.00950655e-02 ... 4.43156436e-03\n", - " 1.16122244e-02 1.44715561e-02]\n", - " [-3.36921120e+00 1.73153952e-02 -6.36872873e-02 ... 3.28363618e-03\n", - " 8.58010259e-03 1.07794888e-02]\n", - " [-6.62045336e+00 3.49955931e-02 -1.23962618e-01 ... 6.36671018e-03\n", - " 1.60814095e-02 2.03891303e-02]]\n", - "[-4.3777819e+00 2.3245810e-02 -9.3339294e-02 ... 4.2569344e-03\n", - " 1.0919910e-02 1.3787797e-02]\n" - ] - } - ], - "source": [ - "from paddle.nn import functional as F\n", - "def ctc_loss(logits,\n", - " labels,\n", - " input_lengths,\n", - " label_lengths,\n", - " blank=0,\n", - " reduction='mean',\n", - " norm_by_times=False):\n", - " loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n", - " input_lengths, label_lengths)\n", - " loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n", - " assert reduction in ['mean', 'sum', 'none']\n", - " if reduction == 'mean':\n", - " loss_out = paddle.mean(loss_out / label_lengths)\n", - " elif reduction == 'sum':\n", - " loss_out = paddle.sum(loss_out)\n", - " return loss_out\n", - "\n", - "F.ctc_loss = ctc_loss\n", - "\n", - "torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n", - "encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n", - "loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n", - "print(loss_ctc)\n", - "loss_ctc.backward()\n", - "print(loss_ctc.grad)\n", - "print(model.ctc.ctc_lo.weight.grad)\n", - "print(model.ctc.ctc_lo.bias.grad)\n", - "\n", - "\n", - "# tensor(377.3326, device='cuda:0', grad_fn=)\n", - "# None\n", - "# [[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", - "# -5.93366381e-03 -7.26613170e-03]\n", - "# [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", - "# 2.46338220e-03 2.31891591e-03]\n", - "# [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", - "# 5.76929189e-03 7.48792710e-03]\n", - "# ...\n", - "# [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", - "# 1.16123557e-02 1.44716976e-02]\n", - "# [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", - "# 8.58021621e-03 1.07796099e-02]\n", - "# [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", - "# 1.60815325e-02 2.03892551e-02]]\n", - "# [-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", - "# 1.0920014e-02 1.3787906e-02]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "enclosed-consolidation", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "synthetic-hungarian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118]) 0.0\n" - ] - } - ], - "source": [ - "loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n", - " text, text_len)\n", - "print(loss_att, acc_att)\n", - "#tensor(41.8416, device='cuda:0', grad_fn=) 0.0" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "indian-sweden", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 202, - "id": "marine-cuisine", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", - " 1.5034772e-02 4.0337229e-01]\n", - " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", - " -1.4352810e-01 -1.0023664e+00]\n", - " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", - " -1.1672243e+00 -2.6848501e-01]\n", - " ...\n", - " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", - " 4.5470044e-02 -3.7139410e-01]\n", - " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", - " -7.9347193e-04 4.2537671e-01]\n", - " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 2.6173013e-01\n", - " -1.0499060e-03 4.2678756e-01]]\n" - ] - } - ], - "source": [ - "data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n", - "torch_decoder_out = data['decoder_out']\n", - "print(torch_decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 180, - "id": "several-result", - "metadata": {}, - "outputs": [], - "source": [ - "def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n", - " ignore_id: int):\n", - " \"\"\"Add and labels.\n", - " Args:\n", - " ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n", - " sos (int): index of \n", - " eos (int): index of \n", - " ignore_id (int): index of padding\n", - " Returns:\n", - " ys_in (paddle.Tensor) : (B, Lmax + 1)\n", - " ys_out (paddle.Tensor) : (B, Lmax + 1)\n", - " Examples:\n", - " >>> sos_id = 10\n", - " >>> eos_id = 11\n", - " >>> ignore_id = -1\n", - " >>> ys_pad\n", - " tensor([[ 1, 2, 3, 4, 5],\n", - " [ 4, 5, 6, -1, -1],\n", - " [ 7, 8, 9, -1, -1]], dtype=paddle.int32)\n", - " >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n", - " >>> ys_in\n", - " tensor([[10, 1, 2, 3, 4, 5],\n", - " [10, 4, 5, 6, 11, 11],\n", - " [10, 7, 8, 9, 11, 11]])\n", - " >>> ys_out\n", - " tensor([[ 1, 2, 3, 4, 5, 11],\n", - " [ 4, 5, 6, 11, -1, -1],\n", - " [ 7, 8, 9, 11, -1, -1]])\n", - " \"\"\"\n", - " # TODO(Hui Zhang): using comment code, \n", - " #_sos = paddle.to_tensor(\n", - " # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", - " #_eos = paddle.to_tensor(\n", - " # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", - " #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", - " #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n", - " #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n", - " #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n", - " B = ys_pad.size(0)\n", - " _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n", - " _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n", - " ys_in = paddle.cat([_sos, ys_pad], dim=1)\n", - " mask_pad = (ys_in == ignore_id)\n", - " ys_in = ys_in.masked_fill(mask_pad, eos)\n", - " \n", - "\n", - " ys_out = paddle.cat([ys_pad, _eos], dim=1)\n", - " ys_out = ys_out.masked_fill(mask_pad, eos)\n", - " mask_eos = (ys_out == ignore_id)\n", - " ys_out = ys_out.masked_fill(mask_eos, eos)\n", - " ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n", - " return ys_in, ys_out" - ] - }, - { - "cell_type": "code", - "execution_count": 181, - "id": "possible-bulgaria", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n", - " [4232, 236 , 1176, 331 , 66 , 3925, 4077],\n", - " [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n", - " [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n", - " [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n", - " [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n", - " [4232, 25 , 1149, 3930, 4232, 4232, 4232],\n", - " [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n", - " [4232, 3703, 2 , 565 , 3827, 4232, 4232],\n", - " [4232, 1150, 2734, 10 , 2478, 3490, 4232],\n", - " [4232, 426 , 811 , 95 , 489 , 144 , 4232],\n", - " [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n", - " [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n", - " [4232, 70 , 1741, 702 , 1666, 4232, 4232],\n", - " [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n", - " [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n", - "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n", - " [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1 ],\n", - " [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n", - " [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n", - " [ 426, 811, 95 , 489, 144, 4232, -1 ],\n", - " [2313, 2006, 489, 975, 4232, -1 , -1 ],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n", - " [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n", - " [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n" - ] - } - ], - "source": [ - "ys_pad = text\n", - "ys_pad_lens = text_len\n", - "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", - " model.ignore_id)\n", - "ys_in_lens = ys_pad_lens + 1\n", - "print(ys_in_pad)\n", - "print(ys_out_pad)" - ] - }, - { - "cell_type": "code", - "execution_count": 285, - "id": "north-walter", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n", - "True\n", - "False\n", - "[[-3.76389682e-01 -8.22720408e-01 7.42762923e-01 ... 3.42005253e-01\n", - " 1.50350705e-02 4.03372347e-01]\n", - " [-8.73864174e-01 -3.13894272e-01 4.19878662e-01 ... 3.77237231e-01\n", - " -1.43528014e-01 -1.00236630e+00]\n", - " [-4.35050905e-01 3.45046446e-02 -2.87102997e-01 ... 7.72742853e-02\n", - " -1.16722476e+00 -2.68485069e-01]\n", - " ...\n", - " [ 4.24714804e-01 5.88856399e-01 2.02039629e-02 ... 3.74054879e-01\n", - " 4.54700664e-02 -3.71394157e-01]\n", - " [-3.79784584e-01 -8.10841978e-01 7.57250786e-01 ... 2.60389000e-01\n", - " -7.93404877e-04 4.25376773e-01]\n", - " [-3.82798851e-01 -8.12067091e-01 7.49434292e-01 ... 2.61730075e-01\n", - " -1.04988366e-03 4.26787734e-01]]\n", - "---\n", - "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", - " 1.5034772e-02 4.0337229e-01]\n", - " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", - " -1.4352810e-01 -1.0023664e+00]\n", - " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", - " -1.1672243e+00 -2.6848501e-01]\n", - " ...\n", - " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", - " 4.5470044e-02 -3.7139410e-01]\n", - " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", - " -7.9347193e-04 4.2537671e-01]\n", - " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 2.6173013e-01\n", - " -1.0499060e-03 4.2678756e-01]]\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out))\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-6))\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-7))\n", - "print(decoder_out.numpy()[0])\n", - "print('---')\n", - "print(torch_decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "armed-cowboy", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fifty-earth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "proud-commonwealth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 183, - "id": "assisted-fortune", - "metadata": {}, - "outputs": [], - "source": [ - "from paddle import nn\n", - "import paddle\n", - "from paddle.nn import functional as F\n", - "\n", - "class LabelSmoothingLoss(nn.Layer):\n", - "\n", - " def __init__(self,\n", - " size: int,\n", - " padding_idx: int,\n", - " smoothing: float,\n", - " normalize_length: bool=False):\n", - " super().__init__()\n", - " self.size = size\n", - " self.padding_idx = padding_idx\n", - " self.smoothing = smoothing\n", - " self.confidence = 1.0 - smoothing\n", - " self.normalize_length = normalize_length\n", - " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", - "\n", - " def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:\n", - " \"\"\"Compute loss between x and target.\n", - " The model outputs and data labels tensors are flatten to\n", - " (batch*seqlen, class) shape and a mask is applied to the\n", - " padding part which should not be calculated for loss.\n", - " \n", - " Args:\n", - " x (paddle.Tensor): prediction (batch, seqlen, class)\n", - " target (paddle.Tensor):\n", - " target signal masked with self.padding_id (batch, seqlen)\n", - " Returns:\n", - " loss (paddle.Tensor) : The KL loss, scalar float value\n", - " \"\"\"\n", - " B, T, D = paddle.shape(x)\n", - " assert D == self.size\n", - " x = x.reshape((-1, self.size))\n", - " target = target.reshape([-1])\n", - "\n", - " # use zeros_like instead of torch.no_grad() for true_dist,\n", - " # since no_grad() can not be exported by JIT\n", - " true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))\n", - " ignore = target == self.padding_idx # (B,)\n", - " print(self.smoothing / (self.size - 1))\n", - " print(true_dist)\n", - "\n", - " #target = target * (1 - ignore) # avoid -1 index\n", - " target = target.masked_fill(ignore, 0) # avoid -1 index\n", - " \n", - " \n", - " #true_dist += F.one_hot(target, self.size) * self.confidence\n", - " target_mask = F.one_hot(target, self.size)\n", - " true_dist *= (1 - target_mask)\n", - " true_dist += target_mask * self.confidence\n", - " \n", - "\n", - " kl = self.criterion(F.log_softmax(x, axis=1), true_dist)\n", - " \n", - " #TODO(Hui Zhang): sum not support bool type\n", - " #total = len(target) - int(ignore.sum())\n", - " total = len(target) - int(ignore.type_as(target).sum())\n", - " denom = total if self.normalize_length else B\n", - "\n", - " #numer = (kl * (1 - ignore)).sum()\n", - " numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", - " return numer / denom\n" - ] - }, - { - "cell_type": "code", - "execution_count": 184, - "id": "weighted-delight", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.3629489603024576e-05\n", - "Tensor(shape=[112, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " ...,\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363]])\n", - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [41.84146118])\n", - "VarType.INT64\n" - ] - } - ], - "source": [ - "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", - "loss_att = criteron(paddle.to_tensor(torch_decoder_out), ys_out_pad.astype('int64'))\n", - "print(loss_att)\n", - "print(ys_out_pad.dtype)\n", - "# tensor(41.8416, device='cuda:0', grad_fn=)" - ] - }, - { - "cell_type": "code", - "execution_count": 286, - "id": "dress-shelter", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [41.84146118])\n", - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118])\n", - "4233\n", - "-1\n", - "0.1\n", - "False\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "\n", - "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n", - "print(loss_att)\n", - "\n", - "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "\n", - "print(model.criterion_att.size)\n", - "print(model.criterion_att.padding_idx)\n", - "print(model.criterion_att.smoothing)\n", - "print(model.criterion_att.normalize_length)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "growing-tooth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "going-hungary", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "naughty-citizenship", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "experimental-emerald", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adverse-saskatchewan", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "speaking-shelf", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List\n", - "from typing import Optional\n", - "from typing import Tuple\n", - "\n", - "import paddle\n", - "from paddle import nn\n", - "from typeguard import check_argument_types\n", - "\n", - "from deepspeech.modules.activation import get_activation\n", - "from deepspeech.modules.attention import MultiHeadedAttention\n", - "from deepspeech.modules.attention import RelPositionMultiHeadedAttention\n", - "from deepspeech.modules.conformer_convolution import ConvolutionModule\n", - "from deepspeech.modules.embedding import PositionalEncoding\n", - "from deepspeech.modules.embedding import RelPositionalEncoding\n", - "from deepspeech.modules.encoder_layer import ConformerEncoderLayer\n", - "from deepspeech.modules.encoder_layer import TransformerEncoderLayer\n", - "from deepspeech.modules.mask import add_optional_chunk_mask\n", - "from deepspeech.modules.mask import make_non_pad_mask\n", - "from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling4\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling6\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling8\n", - "from deepspeech.modules.subsampling import LinearNoSubsampling\n", - "\n", - "class BaseEncoder(nn.Layer):\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " output_size: int=256,\n", - " attention_heads: int=4,\n", - " linear_units: int=2048,\n", - " num_blocks: int=6,\n", - " dropout_rate: float=0.1,\n", - " positional_dropout_rate: float=0.1,\n", - " attention_dropout_rate: float=0.0,\n", - " input_layer: str=\"conv2d\",\n", - " pos_enc_layer_type: str=\"abs_pos\",\n", - " normalize_before: bool=True,\n", - " concat_after: bool=False,\n", - " static_chunk_size: int=0,\n", - " use_dynamic_chunk: bool=False,\n", - " global_cmvn: paddle.nn.Layer=None,\n", - " use_dynamic_left_chunk: bool=False, ):\n", - " \"\"\"\n", - " Args:\n", - " input_size (int): input dim, d_feature\n", - " output_size (int): dimension of attention, d_model\n", - " attention_heads (int): the number of heads of multi head attention\n", - " linear_units (int): the hidden units number of position-wise feed\n", - " forward\n", - " num_blocks (int): the number of encoder blocks\n", - " dropout_rate (float): dropout rate\n", - " attention_dropout_rate (float): dropout rate in attention\n", - " positional_dropout_rate (float): dropout rate after adding\n", - " positional encoding\n", - " input_layer (str): input layer type.\n", - " optional [linear, conv2d, conv2d6, conv2d8]\n", - " pos_enc_layer_type (str): Encoder positional encoding layer type.\n", - " opitonal [abs_pos, scaled_abs_pos, rel_pos]\n", - " normalize_before (bool):\n", - " True: use layer_norm before each sub-block of a layer.\n", - " False: use layer_norm after each sub-block of a layer.\n", - " concat_after (bool): whether to concat attention layer's input\n", - " and output.\n", - " True: x -> x + linear(concat(x, att(x)))\n", - " False: x -> x + att(x)\n", - " static_chunk_size (int): chunk size for static chunk training and\n", - " decoding\n", - " use_dynamic_chunk (bool): whether use dynamic chunk size for\n", - " training or not, You can only use fixed chunk(chunk_size > 0)\n", - " or dyanmic chunk size(use_dynamic_chunk = True)\n", - " global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer\n", - " use_dynamic_left_chunk (bool): whether use dynamic left chunk in\n", - " dynamic chunk training\n", - " \"\"\"\n", - " assert check_argument_types()\n", - " super().__init__()\n", - " self._output_size = output_size\n", - "\n", - " if pos_enc_layer_type == \"abs_pos\":\n", - " pos_enc_class = PositionalEncoding\n", - " elif pos_enc_layer_type == \"rel_pos\":\n", - " pos_enc_class = RelPositionalEncoding\n", - " else:\n", - " raise ValueError(\"unknown pos_enc_layer: \" + pos_enc_layer_type)\n", - "\n", - " if input_layer == \"linear\":\n", - " subsampling_class = LinearNoSubsampling\n", - " elif input_layer == \"conv2d\":\n", - " subsampling_class = Conv2dSubsampling4\n", - " elif input_layer == \"conv2d6\":\n", - " subsampling_class = Conv2dSubsampling6\n", - " elif input_layer == \"conv2d8\":\n", - " subsampling_class = Conv2dSubsampling8\n", - " else:\n", - " raise ValueError(\"unknown input_layer: \" + input_layer)\n", - "\n", - " self.global_cmvn = global_cmvn\n", - " self.embed = subsampling_class(\n", - " idim=input_size,\n", - " odim=output_size,\n", - " dropout_rate=dropout_rate,\n", - " pos_enc_class=pos_enc_class(\n", - " d_model=output_size, dropout_rate=positional_dropout_rate), )\n", - "\n", - " self.normalize_before = normalize_before\n", - " self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)\n", - " self.static_chunk_size = static_chunk_size\n", - " self.use_dynamic_chunk = use_dynamic_chunk\n", - " self.use_dynamic_left_chunk = use_dynamic_left_chunk\n", - "\n", - " def output_size(self) -> int:\n", - " return self._output_size\n", - "\n", - " def forward(\n", - " self,\n", - " xs: paddle.Tensor,\n", - " xs_lens: paddle.Tensor,\n", - " decoding_chunk_size: int=0,\n", - " num_decoding_left_chunks: int=-1,\n", - " ) -> Tuple[paddle.Tensor, paddle.Tensor]:\n", - " \"\"\"Embed positions in tensor.\n", - " Args:\n", - " xs: padded input tensor (B, L, D)\n", - " xs_lens: input length (B)\n", - " decoding_chunk_size: decoding chunk size for dynamic chunk\n", - " 0: default for training, use random dynamic chunk.\n", - " <0: for decoding, use full chunk.\n", - " >0: for decoding, use fixed chunk size as set.\n", - " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", - " the chunk size is decoding_chunk_size.\n", - " >=0: use num_decoding_left_chunks\n", - " <0: use all left chunks\n", - " Returns:\n", - " encoder output tensor, lens and mask\n", - " \"\"\"\n", - " masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)\n", - "\n", - " if self.global_cmvn is not None:\n", - " xs = self.global_cmvn(xs)\n", - " #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor\n", - " xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)\n", - " #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor\n", - " masks = masks.astype(paddle.bool)\n", - " #TODO(Hui Zhang): mask_pad = ~masks\n", - " mask_pad = masks.logical_not()\n", - " chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,\n", - " decoding_chunk_size, self.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - " for layer in self.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " if self.normalize_before:\n", - " xs = self.after_norm(xs)\n", - " # Here we assume the mask is not changed in encoder layers, so just\n", - " # return the masks before encoder layers, and the masks will be used\n", - " # for cross attention with decoder later\n", - " return xs, masks" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "sharp-municipality", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "class ConformerEncoder(BaseEncoder):\n", - " \"\"\"Conformer encoder module.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " output_size: int=256,\n", - " attention_heads: int=4,\n", - " linear_units: int=2048,\n", - " num_blocks: int=6,\n", - " dropout_rate: float=0.1,\n", - " positional_dropout_rate: float=0.1,\n", - " attention_dropout_rate: float=0.0,\n", - " input_layer: str=\"conv2d\",\n", - " pos_enc_layer_type: str=\"rel_pos\",\n", - " normalize_before: bool=True,\n", - " concat_after: bool=False,\n", - " static_chunk_size: int=0,\n", - " use_dynamic_chunk: bool=False,\n", - " global_cmvn: nn.Layer=None,\n", - " use_dynamic_left_chunk: bool=False,\n", - " positionwise_conv_kernel_size: int=1,\n", - " macaron_style: bool=True,\n", - " selfattention_layer_type: str=\"rel_selfattn\",\n", - " activation_type: str=\"swish\",\n", - " use_cnn_module: bool=True,\n", - " cnn_module_kernel: int=15,\n", - " causal: bool=False,\n", - " cnn_module_norm: str=\"batch_norm\", ):\n", - " \"\"\"Construct ConformerEncoder\n", - " Args:\n", - " input_size to use_dynamic_chunk, see in BaseEncoder\n", - " positionwise_conv_kernel_size (int): Kernel size of positionwise\n", - " conv1d layer.\n", - " macaron_style (bool): Whether to use macaron style for\n", - " positionwise layer.\n", - " selfattention_layer_type (str): Encoder attention layer type,\n", - " the parameter has no effect now, it's just for configure\n", - " compatibility.\n", - " activation_type (str): Encoder activation function type.\n", - " use_cnn_module (bool): Whether to use convolution module.\n", - " cnn_module_kernel (int): Kernel size of convolution module.\n", - " causal (bool): whether to use causal convolution or not.\n", - " cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']\n", - " \"\"\"\n", - " assert check_argument_types()\n", - " super().__init__(input_size, output_size, attention_heads, linear_units,\n", - " num_blocks, dropout_rate, positional_dropout_rate,\n", - " attention_dropout_rate, input_layer,\n", - " pos_enc_layer_type, normalize_before, concat_after,\n", - " static_chunk_size, use_dynamic_chunk, global_cmvn,\n", - " use_dynamic_left_chunk)\n", - " activation = get_activation(activation_type)\n", - "\n", - " # self-attention module definition\n", - " encoder_selfattn_layer = RelPositionMultiHeadedAttention\n", - " encoder_selfattn_layer_args = (attention_heads, output_size,\n", - " attention_dropout_rate)\n", - " # feed-forward module definition\n", - " positionwise_layer = PositionwiseFeedForward\n", - " positionwise_layer_args = (output_size, linear_units, dropout_rate,\n", - " activation)\n", - " # convolution module definition\n", - " convolution_layer = ConvolutionModule\n", - " convolution_layer_args = (output_size, cnn_module_kernel, activation,\n", - " cnn_module_norm, causal)\n", - "\n", - " self.encoders = nn.LayerList([\n", - " ConformerEncoderLayer(\n", - " size=output_size,\n", - " self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n", - " feed_forward=positionwise_layer(*positionwise_layer_args),\n", - " feed_forward_macaron=positionwise_layer(\n", - " *positionwise_layer_args) if macaron_style else None,\n", - " conv_module=convolution_layer(*convolution_layer_args)\n", - " if use_cnn_module else None,\n", - " dropout_rate=dropout_rate,\n", - " normalize_before=normalize_before,\n", - " concat_after=concat_after) for _ in range(num_blocks)\n", - " ])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "tutorial-syndication", - "metadata": {}, - "outputs": [], - "source": [ - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.modules.cmvn import GlobalCMVN\n", - "\n", - "configs=cfg.model\n", - "mean, istd = load_cmvn(configs['cmvn_file'],\n", - " configs['cmvn_file_type'])\n", - "global_cmvn = GlobalCMVN(\n", - " paddle.to_tensor(mean, dtype=paddle.float),\n", - " paddle.to_tensor(istd, dtype=paddle.float))\n", - "\n", - "\n", - "input_dim = configs['input_dim']\n", - "vocab_size = configs['output_dim']\n", - "encoder_type = configs.get('encoder', 'transformer')\n", - " \n", - "encoder = ConformerEncoder(\n", - " input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "fuzzy-register", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "o = global_cmvn(feat)\n", - "o2 = model.encoder.global_cmvn(feat)\n", - "print(np.allclose(o.numpy(), o2.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "explicit-triumph", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "humanitarian-belgium", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dying-proposal", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "honest-quick", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bound-cholesterol", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "viral-packaging", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 203, - "id": "balanced-locator", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 1, 207], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[True , True , True , ..., True , True , True ]],\n", - "\n", - " [[True , True , True , ..., True , True , True ]],\n", - "\n", - " [[True , True , True , ..., True , False, False]],\n", - "\n", - " ...,\n", - "\n", - " [[True , True , True , ..., False, False, False]],\n", - "\n", - " [[True , True , True , ..., False, False, False]],\n", - "\n", - " [[True , True , True , ..., False, False, False]]])\n" - ] - } - ], - "source": [ - "from deepspeech.modules.mask import make_non_pad_mask\n", - "from deepspeech.modules.mask import make_pad_mask\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "print(masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 204, - "id": "induced-proposition", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 207, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[-0.53697914, -0.19910523, -0.34997201, ..., -0.82427669, -1.02650309, -0.96300691],\n", - " [-0.04464225, 0.23176001, -0.32538742, ..., -0.90158713, -1.03248465, -0.75986791],\n", - " [ 0.50035292, 0.22691160, -0.73052198, ..., -1.00552964, -0.87123060, -1.03062117],\n", - " ...,\n", - " [-0.40023831, -0.14325078, -0.57947433, ..., -1.07178426, -1.28059900, -1.05180073],\n", - " [ 0.15755332, -0.00184949, -0.28702953, ..., -1.10898709, -0.94518697, -0.72506356],\n", - " [-0.47520429, -1.39415145, -0.25754252, ..., -1.13649082, -1.19430351, -1.22903371]],\n", - "\n", - " [[ 0.95454037, 0.36427975, -1.38908529, ..., -1.16366839, -1.28453600, -1.20151031],\n", - " [-0.08573537, -1.05785275, -0.89172721, ..., -0.96440506, -1.12547100, -1.25990939],\n", - " [ 0.47653601, 0.32886592, -0.59200549, ..., -1.19421589, -1.14302588, -1.02422845],\n", - " ...,\n", - " [-0.47431335, -0.33558893, -0.72325647, ..., -1.45058632, -1.39574063, -1.04641151],\n", - " [ 0.36112556, 0.10380996, -1.15994537, ..., -1.04394984, -1.02212358, -1.02083635],\n", - " [-1.27172923, -2.14601755, -0.75676596, ..., -0.97822225, -0.93785471, -1.03707945]],\n", - "\n", - " [[-1.54652190, -1.01517177, -0.88900733, ..., -0.48522446, -0.75163364, -0.67765164],\n", - " [-0.76100892, -0.73351598, -0.91587651, ..., -0.24835993, -0.58927339, -0.73722762],\n", - " [-0.02471367, 0.17015894, -0.42326337, ..., -0.33203802, -0.76695800, -0.71651691],\n", - " ...,\n", - " [-1.70319796, -1.25910866, -1.14492917, ..., -1.18101490, -1.11631835, -0.93108195],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.64982772, 0.26116797, -0.84196597, ..., -0.87213463, -1.10728693, -1.32531130],\n", - " [ 0.35391113, -0.01584581, -0.40424931, ..., -0.99173468, -1.07270539, -1.19239008],\n", - " [ 0.37704495, -0.06278508, -0.11467686, ..., -1.10212946, -1.09524000, -1.11815071],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " [[ 0.04445776, -0.17546852, -0.67475224, ..., -0.49801198, -0.56782746, -0.77852231],\n", - " [-1.34279025, -0.80342549, -0.90457231, ..., -0.65901577, -0.72549772, -0.62796098],\n", - " [-0.76252806, -0.13071291, -0.13280024, ..., -0.56132573, -0.60587686, -0.72114766],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " [[-1.07980299, -1.08341801, -1.17969072, ..., -0.17757270, -0.43746525, -0.04000654],\n", - " [ 0.92353648, 0.63770926, -0.52810186, ..., -0.12927933, -0.20342292, 0.16655664],\n", - " [ 0.49337494, -0.00911332, -0.73301607, ..., 0.10074048, -0.09811471, -0.00923573],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]]])\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "print(xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 205, - "id": "cutting-julian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 256, 51, 19], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0.00209083],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0.01194306, 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0.04610471, 0. ],\n", - " [0. , 0. , 0. , ..., 0.00967231, 0.04613467, 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.22816099, 0.24614786, 0.25304127, ..., 0.20401822, 0.23248228, 0.31190544],\n", - " [0.13587360, 0.28877240, 0.27991283, ..., 0.19210319, 0.20346391, 0.19934426],\n", - " [0.25739068, 0.39348233, 0.27877361, ..., 0.27482539, 0.19302306, 0.23810163],\n", - " ...,\n", - " [0.11939213, 0.28473237, 0.33082074, ..., 0.23838061, 0.22104350, 0.23905794],\n", - " [0.17387670, 0.20402060, 0.40263173, ..., 0.24782266, 0.26742202, 0.15426503],\n", - " [0. , 0.29080707, 0.27725950, ..., 0.17539823, 0.18478745, 0.22483408]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.35446781, 0.38861471, 0.39724261, ..., 0.38680089, 0.33568040, 0.34552398],\n", - " [0.41739127, 0.51038563, 0.41729912, ..., 0.33992639, 0.37081629, 0.35109508],\n", - " [0.36116859, 0.40744874, 0.48490953, ..., 0.34848654, 0.32321057, 0.35188958],\n", - " ...,\n", - " [0.23143977, 0.38021481, 0.51526314, ..., 0.36499465, 0.37411752, 0.39986172],\n", - " [0.34678638, 0.40238205, 0.50076538, ..., 0.36184520, 0.31596646, 0.36334658],\n", - " [0.36498138, 0.37943166, 0.51718897, ..., 0.31798238, 0.33656698, 0.34130475]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.01456045, 0.09447514, 0. , ..., 0. , 0. , 0. ],\n", - " [0.01500242, 0.02963220, 0. , ..., 0. , 0. , 0. ],\n", - " [0.03295187, 0. , 0. , ..., 0.04584959, 0.02043908, 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0.04425837],\n", - " [0. , 0. , 0.02556529, ..., 0. , 0.00900441, 0.04908358]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.11141267, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.33696529, 0.38526866, 0.32900479, ..., 0.28703830, 0.23351061, 0.19004467],\n", - " [0.13575366, 0.35783342, 0.33573425, ..., 0.22081660, 0.15854910, 0.13587447],\n", - " [0.21928655, 0.28900093, 0.28255141, ..., 0.20602837, 0.23927397, 0.21909429],\n", - " ...,\n", - " [0.23291890, 0.39096734, 0.36399242, ..., 0.20598020, 0.25373828, 0.23137446],\n", - " [0.18739152, 0.30793777, 0.30296701, ..., 0.27250600, 0.25191751, 0.20836820],\n", - " [0.22454213, 0.41402060, 0.54082996, ..., 0.31874508, 0.25079906, 0.25938687]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.26456982, 0.49519050, 0.56702250, ..., 0.30954638, 0.35292268, 0.32668519],\n", - " [0.21576807, 0.51833367, 0.49183372, ..., 0.36043224, 0.38523889, 0.36154741],\n", - " [0.20067888, 0.42784205, 0.52817714, ..., 0.31871423, 0.32452232, 0.31036487],\n", - " ...,\n", - " [0.49855131, 0.51001430, 0.52278662, ..., 0.36450142, 0.34338164, 0.33602941],\n", - " [0.41233343, 0.55517823, 0.52827710, ..., 0.40675971, 0.33873138, 0.36724189],\n", - " [0.40820011, 0.46187383, 0.47338152, ..., 0.38690975, 0.36039269, 0.38022059]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0.00578516, 0. , ..., 0.00748384, 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0.03035110, 0. , 0.00026720],\n", - " [0.00094807, 0. , 0. , ..., 0.00795512, 0. , 0. ],\n", - " ...,\n", - " [0.02032628, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0.01080076, 0. ],\n", - " [0.18470290, 0. , 0. , ..., 0.05058352, 0.09475817, 0.05914564]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.38708323, 0.28021947, 0.35892880, ..., 0.16595127, 0.16031364, 0.21136315],\n", - " [0.15595171, 0.30544323, 0.24666184, ..., 0.22675267, 0.25765014, 0.19682154],\n", - " [0.29517862, 0.41209796, 0.20063159, ..., 0.17595036, 0.22536841, 0.22214051],\n", - " ...,\n", - " [0.24744980, 0.26258564, 0.38654143, ..., 0.23620218, 0.23157144, 0.18514194],\n", - " [0.25714791, 0.29592845, 0.47744542, ..., 0.23545510, 0.25072727, 0.20976165],\n", - " [1.20154655, 0.84644288, 0.73385584, ..., 1.02517247, 0.95309550, 1.00134516]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.45013186, 0.47484034, 0.40540054, ..., 0.19346163, 0.17825794, 0.14776605],\n", - " [0.47545874, 0.48186573, 0.36760187, ..., 0.27809089, 0.32997063, 0.32337096],\n", - " [0.46160024, 0.40050328, 0.39060861, ..., 0.36612910, 0.35242686, 0.29738861],\n", - " ...,\n", - " [0.55148494, 0.51017821, 0.40132499, ..., 0.38948193, 0.35737294, 0.33088297],\n", - " [0.41972569, 0.45475486, 0.45320493, ..., 0.38343129, 0.40125814, 0.36180776],\n", - " [0.34279808, 0.31606171, 0.44701228, ..., 0.21665487, 0.23984617, 0.23903391]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.04178291, 0. , 0.01580476, ..., 0. , 0.02250817, 0. ],\n", - " [0.04323414, 0.07786420, 0. , ..., 0.01634724, 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.03209178, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.13563479, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0. , 0.25187218, 0.24979387, ..., 0.24774717, 0.22354351, 0.19149347],\n", - " [0.16540922, 0.19585510, 0.19812922, ..., 0.27344131, 0.20928150, 0.26150429],\n", - " [0.10494646, 0.06329897, 0.33843631, ..., 0.25138417, 0.12470355, 0.23926635],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.11428106, 0.45667490, 0.46820879, ..., 0.32057840, 0.33578536, 0.39012644],\n", - " [0.10441341, 0.45739070, 0.46107352, ..., 0.38467997, 0.38291249, 0.36685589],\n", - " [0.19867736, 0.35519636, 0.44313061, ..., 0.40679252, 0.38067645, 0.30645671],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.02465414, 0. , 0. , ..., 0. , 0. , 0.03390232],\n", - " [0. , 0. , 0.01830704, ..., 0.05166877, 0.00948385, 0.07453502],\n", - " [0.09921519, 0. , 0.01587192, ..., 0.01620276, 0.05140074, 0.00192392],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.40034360, 0.25306445, 0.20217699, ..., 0.09816189, 0.07064310, 0.04974059],\n", - " [0.12567598, 0.21030979, 0.11181555, ..., 0.04278110, 0.11968569, 0.12005232],\n", - " [0.28786880, 0.24030517, 0.22565845, ..., 0. , 0.06418110, 0.05872961],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.38404641, 0.30990323, 0.37156230, ..., 0.18125033, 0.15050662, 0.19619957],\n", - " [0.47285745, 0.40528792, 0.39718056, ..., 0.24709940, 0.04565683, 0.11500744],\n", - " [0.32620737, 0.30072594, 0.30477354, ..., 0.23529193, 0.21356541, 0.16985542],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.03343770, 0.00123780, 0.05297198, ..., 0.07271163, 0.08656286, 0.14493589],\n", - " [0.11043239, 0.06143146, 0.06362963, ..., 0.08127750, 0.06259022, 0.08315435],\n", - " [0.01767678, 0.00201111, 0.07875030, ..., 0.06963293, 0.08979890, 0.05326346],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.10033827, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.15627117, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.05144687, 0. , 0. , ..., 0. , 0. , 0.00436414],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.25142455, 0.45964020, 0.37346074, ..., 0.04763087, 0. , 0. ],\n", - " [0.19760093, 0.26626948, 0.11190540, ..., 0.03044968, 0. , 0. ],\n", - " [0.16340607, 0.32938001, 0.25689697, ..., 0.05569421, 0. , 0. ],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0.02218930, 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0.02848953],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.25810039, 0.63016868, 0.37037861, ..., 0.18704373, 0.08269356, 0.09912672],\n", - " [0.17292863, 0.50678611, 0.40738991, ..., 0.16006103, 0.11725381, 0.09940521],\n", - " [0.24175072, 0.41616210, 0.41256818, ..., 0.13519743, 0.07912572, 0.12846369],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]]])\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "\n", - "#xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "# print(xs)\n", - "\n", - "x = xs.unsqueeze(1)\n", - "x = model.encoder.embed.conv(x)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 206, - "id": "friendly-nightlife", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[-0.03426375, 0.14291267, -0.06718873, ..., 0.09064753, 0.01809387, -0.04340880],\n", - " [-0.05007839, 0.11054724, -0.10399298, ..., 0.11457238, 0.04244684, -0.01249714],\n", - " [-0.10695291, 0.16910909, -0.08352133, ..., 0.07710276, 0.01168563, -0.03584499],\n", - " ...,\n", - " [-0.06060536, 0.14455931, -0.05470302, ..., 0.05364908, 0.03033342, -0.02610814],\n", - " [-0.08505894, 0.13611752, -0.11132983, ..., 0.13079923, 0.01580139, -0.02281028],\n", - " [-0.10604677, 0.14714901, -0.10885533, ..., 0.08543444, 0.03719445, -0.04634233]],\n", - "\n", - " [[-0.12392755, 0.14486063, -0.05674079, ..., 0.02573164, 0.03128851, 0.00545091],\n", - " [-0.04775286, 0.08473608, -0.08507854, ..., 0.04573154, 0.04240163, 0.01053247],\n", - " [-0.05940291, 0.10023535, -0.08143730, ..., 0.03596500, 0.01673085, 0.02089563],\n", - " ...,\n", - " [-0.09222981, 0.15823206, -0.07700447, ..., 0.08122957, 0.03136991, -0.00646474],\n", - " [-0.07331756, 0.14482647, -0.07838815, ..., 0.10869440, 0.01356864, -0.02777974],\n", - " [-0.07937264, 0.20143102, -0.05544947, ..., 0.10287814, 0.00608235, -0.04799180]],\n", - "\n", - " [[-0.03670349, 0.08931590, -0.08718812, ..., 0.01314050, 0.00642052, 0.00573716],\n", - " [ 0.01089254, 0.11146393, -0.10263617, ..., 0.05070438, 0.01960694, 0.03521532],\n", - " [-0.02182280, 0.11443964, -0.06678198, ..., 0.04327708, 0.00861394, 0.02871092],\n", - " ...,\n", - " [-0.06792898, 0.14376275, -0.07899005, ..., 0.11248926, 0.03208683, -0.03264240],\n", - " [-0.07884051, 0.17024788, -0.08583611, ..., 0.09028331, 0.03588808, -0.02075090],\n", - " [-0.13792302, 0.27163863, -0.23930418, ..., 0.13391261, 0.07521040, -0.08621951]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.02446348, 0.11595841, -0.03591986, ..., 0.06288970, 0.02895011, -0.06532725],\n", - " [-0.05378424, 0.12607370, -0.09023033, ..., 0.09078894, 0.01035743, 0.03701983],\n", - " [-0.04566649, 0.14275314, -0.06686870, ..., 0.09890588, -0.00612222, 0.03439377],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", - "\n", - " [[-0.01012144, 0.03909408, -0.07077143, ..., 0.00452683, -0.01377654, 0.02897627],\n", - " [-0.00519154, 0.03594019, -0.06831125, ..., 0.05693541, -0.00406374, 0.04561640],\n", - " [-0.01762631, 0.00500899, -0.05886075, ..., 0.02112178, -0.00729015, 0.02782153],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", - "\n", - " [[-0.03411558, -0.04318277, -0.08497842, ..., -0.04886402, 0.04296734, 0.06151697],\n", - " [ 0.00263296, -0.06913657, -0.08993219, ..., -0.00149064, 0.05696633, 0.03304394],\n", - " [-0.01818341, -0.01178640, -0.09679577, ..., -0.00870231, 0.00362198, 0.01916483],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]]])\n", - "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[-0.54821998, 2.28660274, -1.07501972, ..., 1.45036042, 0.28950194, -0.69454080],\n", - " [-0.80125421, 1.76875579, -1.66388774, ..., 1.83315802, 0.67914939, -0.19995420],\n", - " [-1.71124649, 2.70574546, -1.33634126, ..., 1.23364413, 0.18697014, -0.57351983],\n", - " ...,\n", - " [-0.96968573, 2.31294894, -0.87524825, ..., 0.85838526, 0.48533469, -0.41773027],\n", - " [-1.36094308, 2.17788029, -1.78127730, ..., 2.09278774, 0.25282228, -0.36496443],\n", - " [-1.69674826, 2.35438418, -1.74168527, ..., 1.36695099, 0.59511113, -0.74147725]],\n", - "\n", - " [[-1.98284078, 2.31777000, -0.90785271, ..., 0.41170627, 0.50061619, 0.08721463],\n", - " [-0.76404583, 1.35577726, -1.36125672, ..., 0.73170459, 0.67842603, 0.16851945],\n", - " [-0.95044655, 1.60376561, -1.30299675, ..., 0.57544005, 0.26769355, 0.33433008],\n", - " ...,\n", - " [-1.47567701, 2.53171301, -1.23207152, ..., 1.29967308, 0.50191855, -0.10343577],\n", - " [-1.17308092, 2.31722355, -1.25421047, ..., 1.73911047, 0.21709818, -0.44447583],\n", - " [-1.26996231, 3.22289634, -0.88719147, ..., 1.64605021, 0.09731755, -0.76786882]],\n", - "\n", - " [[-0.58725590, 1.42905438, -1.39500988, ..., 0.21024795, 0.10272825, 0.09179455],\n", - " [ 0.17428070, 1.78342295, -1.64217877, ..., 0.81127012, 0.31371105, 0.56344515],\n", - " [-0.34916472, 1.83103430, -1.06851172, ..., 0.69243336, 0.13782299, 0.45937473],\n", - " ...,\n", - " [-1.08686376, 2.30020404, -1.26384079, ..., 1.79982817, 0.51338923, -0.52227837],\n", - " [-1.26144814, 2.72396612, -1.37337780, ..., 1.44453299, 0.57420933, -0.33201432],\n", - " [-2.20676827, 4.34621811, -3.82886696, ..., 2.14260173, 1.20336640, -1.37951219]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.39141566, 1.85533464, -0.57471782, ..., 1.00623512, 0.46320182, -1.04523599],\n", - " [-0.86054784, 2.01717925, -1.44368529, ..., 1.45262301, 0.16571884, 0.59231722],\n", - " [-0.73066384, 2.28405023, -1.06989920, ..., 1.58249414, -0.09795550, 0.55030036],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", - "\n", - " [[-0.16194311, 0.62550521, -1.13234293, ..., 0.07242929, -0.22042468, 0.46362036],\n", - " [-0.08306468, 0.57504302, -1.09298003, ..., 0.91096652, -0.06501988, 0.72986233],\n", - " [-0.28202093, 0.08014385, -0.94177192, ..., 0.33794850, -0.11664233, 0.44514441],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", - "\n", - " [[-0.54584920, -0.69092435, -1.35965478, ..., -0.78182435, 0.68747747, 0.98427159],\n", - " [ 0.04212743, -1.10618520, -1.43891501, ..., -0.02385022, 0.91146135, 0.52870303],\n", - " [-0.29093450, -0.18858244, -1.54873240, ..., -0.13923697, 0.05795169, 0.30663735],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]]])\n", - "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", - " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", - " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", - " ...,\n", - " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", - " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", - " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" - ] - } - ], - "source": [ - "b, c, t, f = paddle.shape(x)\n", - "x = model.encoder.embed.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))\n", - "print(x)\n", - "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", - "print(x)\n", - "print(pos_emb)" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "id": "guilty-cache", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", - " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", - " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", - " ...,\n", - " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", - " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", - " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" - ] - } - ], - "source": [ - "print(pos_emb)" - ] - }, - { - "cell_type": "code", - "execution_count": 208, - "id": "iraqi-payday", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", - " 0.0000000e+00 1.0000000e+00]\n", - " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", - " 1.0746076e-04 1.0000000e+00]\n", - " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", - " 2.1492151e-04 1.0000000e+00]\n", - " ...\n", - " [ 9.5625257e-01 -2.9254240e-01 4.8925215e-01 ... 8.3807874e-01\n", - " 5.1154459e-01 8.5925674e-01]\n", - " [ 2.7049953e-01 -9.6272010e-01 9.9170387e-01 ... 8.3801574e-01\n", - " 5.1163691e-01 8.5920173e-01]\n", - " [-6.6394955e-01 -7.4777740e-01 6.9544029e-01 ... 8.3795273e-01\n", - " 5.1172924e-01 8.5914677e-01]]]\n", - "[1, 5000, 256]\n" - ] - } - ], - "source": [ - "import torch\n", - "import math\n", - "import numpy as np\n", - "\n", - "max_len=5000\n", - "d_model=256\n", - "\n", - "pe = torch.zeros(max_len, d_model)\n", - "position = torch.arange(0, max_len,\n", - " dtype=torch.float32).unsqueeze(1)\n", - "toruch_position = position\n", - "div_term = torch.exp(\n", - " torch.arange(0, d_model, 2, dtype=torch.float32) *\n", - " -(math.log(10000.0) / d_model))\n", - "tourch_div_term = div_term.cpu().detach().numpy()\n", - "\n", - "torhc_sin = torch.sin(position * div_term)\n", - "torhc_cos = torch.cos(position * div_term)\n", - "\n", - "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", - "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", - "pe[:, 0::2] = torhc_sin\n", - "pe[:, 1::2] = torhc_cos\n", - "pe = pe.unsqueeze(0) \n", - "tourch_pe = pe.cpu().detach().numpy()\n", - "print(tourch_pe)\n", - "bak_pe = model.encoder.embed.pos_enc.pe\n", - "print(bak_pe.shape)\n", - "model.encoder.embed.pos_enc.pe = paddle.to_tensor(tourch_pe)" - ] - }, - { - "cell_type": "code", - "execution_count": 210, - "id": "exempt-cloud", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "#print(xs)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "composite-involvement", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 269, - "id": "handed-harris", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "False\n", - "True\n", - "[256, 2048]\n", - "[2048]\n", - "[2048, 256]\n", - "[256]\n", - "--------ff-------\n", - "True\n", - "False\n", - "False\n", - "False\n", - "False\n", - "True\n", - "linear_714.w_0 True\n", - "linear_714.b_0 True\n", - "linear_715.w_0 True\n", - "linear_715.b_0 True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "for layer in model.encoder.encoders:\n", - " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " print(layer.feed_forward_macaron is not None)\n", - " print(layer.normalize_before)\n", - " \n", - " data = np.load('.notebook/enc_0_norm_ff.npz')\n", - " t_norm_ff = data['norm_ff']\n", - " t_xs = data['xs']\n", - " \n", - " \n", - " x = xs\n", - " print(np.allclose(t_xs, x.numpy()))\n", - " residual = x\n", - " print(np.allclose(t_xs, residual.numpy()))\n", - " x_nrom = layer.norm_ff_macaron(x)\n", - " print(np.allclose(t.numpy(), x_nrom.numpy()))\n", - " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", - "# for n, p in layer.norm_ff_macaron.state_dict().items():\n", - "# print(n, p)\n", - "# pass\n", - "\n", - " layer.eval()\n", - " x_nrom = paddle.to_tensor(t_norm_ff)\n", - " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", - " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_nrom)\n", - " \n", - " ps=[]\n", - " for n, p in layer.feed_forward_macaron.state_dict().items():\n", - " #print(n, p)\n", - " ps.append(p)\n", - " print(p.shape)\n", - " pass\n", - "\n", - " x_nrom = paddle.to_tensor(t_norm_ff)\n", - " ff_l_x = layer.feed_forward_macaron.w_1(x_nrom)\n", - " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", - " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", - " data = np.load('.notebook/enc_0_ff_out.npz', allow_pickle=True)\n", - " t_norm_ff = data['norm_ff']\n", - " t_ff_out = data['ff_out']\n", - " t_ff_l_x = data['ff_l_x']\n", - " t_ff_l_a_x = data['ff_l_a_x']\n", - " t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - " t_ps = data['ps']\n", - " \n", - " print(\"--------ff-------\")\n", - " print(np.allclose(x_nrom.numpy(), t_norm_ff))\n", - " print(np.allclose(x.numpy(), t_ff_out))\n", - " print(np.allclose(ff_l_x.numpy(), t_ff_l_x))\n", - " print(np.allclose(ff_l_a_x.numpy(), t_ff_l_a_x))\n", - " print(np.allclose(ff_l_a_l_x.numpy(), t_ff_l_a_l_x))\n", - " \n", - " print(np.allclose(ff_l_x.numpy(), t_ff_l_x, atol=1e-6))\n", - " for p, t_p in zip(ps, t_ps):\n", - " print(p.name, np.allclose(p.numpy(), t_p.T))\n", - " \n", - " \n", - "# residual = x\n", - "# x = layer.norm_mha(x)\n", - "# x_q = x\n", - " \n", - " data = np.load('.notebook/enc_0_selattn_out.npz', allow_pickle=True)\n", - " tx_q = data['x_q']\n", - " tx = data['x']\n", - " tpos_emb=data['pos_emb']\n", - " tmask=data['mask']\n", - " tt_x_att=data['x_att']\n", - " x_q = paddle.to_tensor(tx_q)\n", - " x = paddle.to_tensor(tx)\n", - " pos_emb = paddle.to_tensor(tpos_emb)\n", - " mask = paddle.to_tensor(tmask)\n", - " \n", - " x_att = layer.self_attn(x_q, x, x, pos_emb, mask)\n", - " print(np.allclose(x_att.numpy(), t_x_att))\n", - " print(np.allclose(x_att.numpy(), t_x_att, atol=1e-6))\n", - " \n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 270, - "id": "sonic-thumb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " break\n", - "data = np.load('.notebook/enc_0.npz')\n", - "torch_xs = data['enc_0']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(xs.numpy(), torch_xs, atol=1e-6))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 273, - "id": "brave-latino", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "--------layers_______\n", - "False\n", - "True\n", - "[[-0.70194244 0.56254214 0.6880346 ... 1.1237319 0.7803924\n", - " 1.1369387 ]\n", - " [-0.7787783 0.3912667 0.71887773 ... 1.251882 0.886168\n", - " 1.3173451 ]\n", - " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", - " 1.2903278 ]\n", - " ...\n", - " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", - " 1.3239655 ]\n", - " [-1.1654118 0.6819967 0.6939453 ... 1.2238353 0.8028295\n", - " 1.4506507 ]\n", - " [-1.2732092 0.7145806 0.75819594 ... 0.94154835 0.8774845\n", - " 1.2623049 ]]\n", - "xxxxxx\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.273209 0.71458095 0.75819623 ... 0.9415484 0.8774842\n", - " 1.2623055 ]]\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "print(\"--------layers_______\")\n", - "i =0\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " i+=1\n", - "# if i == 2:\n", - "# data = np.load('.notebook/enc_2.npz')\n", - "# torch_xs = data['enc_2']\n", - "# print(np.allclose(xs.numpy(), torch_xs))\n", - "# print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", - "# print(xs[0].numpy())\n", - "# print('xxxxxx')\n", - "# print(torch_xs[0])\n", - "# print('----i==2')\n", - "data = np.load('.notebook/enc_all.npz')\n", - "torch_xs = data['enc_all']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", - "print(xs[0].numpy())\n", - "print('xxxxxx')\n", - "print(torch_xs[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "municipal-stock", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 278, - "id": "macro-season", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-0.7019424 0.5625421 0.68803453 ... 1.1237317 0.7803923\n", - " 1.1369386 ]\n", - " [-0.7787783 0.39126673 0.71887773 ... 1.251882 0.886168\n", - " 1.3173451 ]\n", - " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", - " 1.2903278 ]\n", - " ...\n", - " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", - " 1.3239655 ]\n", - " [-1.1654117 0.68199664 0.6939452 ... 1.2238352 0.8028294\n", - " 1.4506506 ]\n", - " [-1.2732091 0.71458054 0.7581958 ... 0.9415482 0.8774844\n", - " 1.2623048 ]]\n", - "---\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n", - " 1.2623053 ]]\n", - "False\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "encoder_out, mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.numpy()[0])\n", - "print(\"---\")\n", - "print(torch_encoder_out[0])\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "associate-sampling", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/u2_tansformer_model_espnet.ipynb b/.notebook/u2_tansformer_model_espnet.ipynb deleted file mode 100644 index 75c2ea5c6..000000000 --- a/.notebook/u2_tansformer_model_espnet.ipynb +++ /dev/null @@ -1,1672 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-grade", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "broke-broad", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "register user softmax to paddle, remove this when fixed!\n", - "register user log_softmax to paddle, remove this when fixed!\n", - "register user sigmoid to paddle, remove this when fixed!\n", - "register user log_sigmoid to paddle, remove this when fixed!\n", - "register user relu to paddle, remove this when fixed!\n", - "override cat of paddle if exists or register, remove this when fixed!\n", - "override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle if exists or register, remove this when fixed!\n", - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "register user view to paddle.Tensor, remove this when fixed!\n", - "register user view_as to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "register user fill_ to paddle.Tensor, remove this when fixed!\n", - "register user repeat to paddle.Tensor, remove this when fixed!\n", - "register user softmax to paddle.Tensor, remove this when fixed!\n", - "register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "register user relu to paddle.Tensor, remove this when fixed!\n", - "register user type_as to paddle.Tensor, remove this when fixed!\n", - "register user to to paddle.Tensor, remove this when fixed!\n", - "register user float to paddle.Tensor, remove this when fixed!\n", - "register user tolist to paddle.Tensor, remove this when fixed!\n", - "register user glu to paddle.nn.functional, remove this when fixed!\n", - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "register user Module to paddle.nn, remove this when fixed!\n", - "register user ModuleList to paddle.nn, remove this when fixed!\n", - "register user GLU to paddle.nn, remove this when fixed!\n", - "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "register user export to paddle.jit, remove this when fixed!\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import paddle\n", - "from yacs.config import CfgNode as CN\n", - "\n", - "from deepspeech.models.u2 import U2Model\n", - "from deepspeech.utils.layer_tools import print_params\n", - "from deepspeech.utils.layer_tools import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "permanent-summary", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n", - "[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n", - "attention_heads: 4\n", - "dropout_rate: 0.1\n", - "input_layer: conv2d\n", - "linear_units: 2048\n", - "normalize_before: True\n", - "num_blocks: 12\n", - "output_size: 256\n", - "positional_dropout_rate: 0.1\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "encoder.embed.conv.0.bias | [256] | 256 | True\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "encoder.embed.conv.2.bias | [256] | 256 | True\n", - "encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n", - "encoder.embed.out.0.bias | [256] | 256 | True\n", - "encoder.after_norm.weight | [256] | 256 | True\n", - "encoder.after_norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", - "decoder.after_norm.weight | [256] | 256 | True\n", - "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", - "decoder.output_layer.bias | [4233] | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", - "ctc.ctc_lo.bias | [4233] | 4233 | True\n", - "Total parameters: 411.0, 32.01M elements.\n" - ] - } - ], - "source": [ - "conf_str='examples/tiny/s1/conf/transformer.yaml'\n", - "cfg = CN().load_cfg(open(conf_str))\n", - "cfg.model.input_dim = 83\n", - "cfg.model.output_dim = 4233\n", - "cfg.model.cmvn_file = None\n", - "cfg.model.cmvn_file_type = 'json'\n", - "#cfg.model.encoder_conf.concat_after=True\n", - "cfg.freeze()\n", - "model = U2Model(cfg.model)\n", - "\n", - "print_params(model)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "sapphire-agent", - "metadata": {}, - "outputs": [], - "source": [ - "#summary(model)\n", - "#print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ruled-invitation", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "fossil-means", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", - "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook/espnet" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "45c2b75f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state\n", - "odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n" - ] - } - ], - "source": [ - "#!pip install torch\n", - "import torch\n", - "\n", - "e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n", - "for k in e_model.files:\n", - " print(k)\n", - "state_dict = e_model['state']\n", - "state_dict = state_dict.tolist()\n", - "print(state_dict.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f187bb55", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n", - "# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n", - "# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n", - "# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n", - "# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n", - "# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n", - "# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n", - "# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n", - "# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n", - "# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "# espnet transformer\tconcat_linear只是保存了,但是未使用\n", - "\t\n", - "# \tpaddle transformer" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2a0428ae", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-> encoder.embed.conv.0.weight\n", - "-> encoder.embed.conv.0.bias\n", - "-> encoder.embed.conv.2.weight\n", - "-> encoder.embed.conv.2.bias\n", - "-> encoder.embed.out.0.weight\n", - "encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n", - "-> encoder.embed.out.0.bias\n", - "-> encoder.encoders.0.self_attn.linear_q.weight\n", - "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_q.bias\n", - "-> encoder.encoders.0.self_attn.linear_k.weight\n", - "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_k.bias\n", - "-> encoder.encoders.0.self_attn.linear_v.weight\n", - "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_v.bias\n", - "-> encoder.encoders.0.self_attn.linear_out.weight\n", - "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_out.bias\n", - "-> encoder.encoders.0.feed_forward.w_1.weight\n", - "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.0.feed_forward.w_1.bias\n", - "-> encoder.encoders.0.feed_forward.w_2.weight\n", - "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.0.feed_forward.w_2.bias\n", - "-> encoder.encoders.0.norm1.weight\n", - "-> encoder.encoders.0.norm1.bias\n", - "-> encoder.encoders.0.norm2.weight\n", - "-> encoder.encoders.0.norm2.bias\n", - "-> encoder.encoders.1.self_attn.linear_q.weight\n", - "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_q.bias\n", - "-> encoder.encoders.1.self_attn.linear_k.weight\n", - "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_k.bias\n", - "-> encoder.encoders.1.self_attn.linear_v.weight\n", - "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_v.bias\n", - "-> encoder.encoders.1.self_attn.linear_out.weight\n", - "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_out.bias\n", - "-> encoder.encoders.1.feed_forward.w_1.weight\n", - "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.1.feed_forward.w_1.bias\n", - "-> encoder.encoders.1.feed_forward.w_2.weight\n", - "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.1.feed_forward.w_2.bias\n", - "-> encoder.encoders.1.norm1.weight\n", - "-> encoder.encoders.1.norm1.bias\n", - "-> encoder.encoders.1.norm2.weight\n", - "-> encoder.encoders.1.norm2.bias\n", - "-> encoder.encoders.2.self_attn.linear_q.weight\n", - "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_q.bias\n", - "-> encoder.encoders.2.self_attn.linear_k.weight\n", - "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_k.bias\n", - "-> encoder.encoders.2.self_attn.linear_v.weight\n", - "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_v.bias\n", - "-> encoder.encoders.2.self_attn.linear_out.weight\n", - "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_out.bias\n", - "-> encoder.encoders.2.feed_forward.w_1.weight\n", - "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.2.feed_forward.w_1.bias\n", - "-> encoder.encoders.2.feed_forward.w_2.weight\n", - "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.2.feed_forward.w_2.bias\n", - "-> encoder.encoders.2.norm1.weight\n", - "-> encoder.encoders.2.norm1.bias\n", - "-> encoder.encoders.2.norm2.weight\n", - "-> encoder.encoders.2.norm2.bias\n", - "-> encoder.encoders.3.self_attn.linear_q.weight\n", - "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_q.bias\n", - "-> encoder.encoders.3.self_attn.linear_k.weight\n", - "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_k.bias\n", - "-> encoder.encoders.3.self_attn.linear_v.weight\n", - "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_v.bias\n", - "-> encoder.encoders.3.self_attn.linear_out.weight\n", - "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_out.bias\n", - "-> encoder.encoders.3.feed_forward.w_1.weight\n", - "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.3.feed_forward.w_1.bias\n", - "-> encoder.encoders.3.feed_forward.w_2.weight\n", - "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.3.feed_forward.w_2.bias\n", - "-> encoder.encoders.3.norm1.weight\n", - "-> encoder.encoders.3.norm1.bias\n", - "-> encoder.encoders.3.norm2.weight\n", - "-> encoder.encoders.3.norm2.bias\n", - "-> encoder.encoders.4.self_attn.linear_q.weight\n", - "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_q.bias\n", - "-> encoder.encoders.4.self_attn.linear_k.weight\n", - "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_k.bias\n", - "-> encoder.encoders.4.self_attn.linear_v.weight\n", - "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_v.bias\n", - "-> encoder.encoders.4.self_attn.linear_out.weight\n", - "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_out.bias\n", - "-> encoder.encoders.4.feed_forward.w_1.weight\n", - "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.4.feed_forward.w_1.bias\n", - "-> encoder.encoders.4.feed_forward.w_2.weight\n", - "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.4.feed_forward.w_2.bias\n", - "-> encoder.encoders.4.norm1.weight\n", - "-> encoder.encoders.4.norm1.bias\n", - "-> encoder.encoders.4.norm2.weight\n", - "-> encoder.encoders.4.norm2.bias\n", - "-> encoder.encoders.5.self_attn.linear_q.weight\n", - "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_q.bias\n", - "-> encoder.encoders.5.self_attn.linear_k.weight\n", - "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_k.bias\n", - "-> encoder.encoders.5.self_attn.linear_v.weight\n", - "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_v.bias\n", - "-> encoder.encoders.5.self_attn.linear_out.weight\n", - "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_out.bias\n", - "-> encoder.encoders.5.feed_forward.w_1.weight\n", - "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.5.feed_forward.w_1.bias\n", - "-> encoder.encoders.5.feed_forward.w_2.weight\n", - "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.5.feed_forward.w_2.bias\n", - "-> encoder.encoders.5.norm1.weight\n", - "-> encoder.encoders.5.norm1.bias\n", - "-> encoder.encoders.5.norm2.weight\n", - "-> encoder.encoders.5.norm2.bias\n", - "-> encoder.encoders.6.self_attn.linear_q.weight\n", - "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_q.bias\n", - "-> encoder.encoders.6.self_attn.linear_k.weight\n", - "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_k.bias\n", - "-> encoder.encoders.6.self_attn.linear_v.weight\n", - "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_v.bias\n", - "-> encoder.encoders.6.self_attn.linear_out.weight\n", - "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_out.bias\n", - "-> encoder.encoders.6.feed_forward.w_1.weight\n", - "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.6.feed_forward.w_1.bias\n", - "-> encoder.encoders.6.feed_forward.w_2.weight\n", - "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.6.feed_forward.w_2.bias\n", - "-> encoder.encoders.6.norm1.weight\n", - "-> encoder.encoders.6.norm1.bias\n", - "-> encoder.encoders.6.norm2.weight\n", - "-> encoder.encoders.6.norm2.bias\n", - "-> encoder.encoders.7.self_attn.linear_q.weight\n", - "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_q.bias\n", - "-> encoder.encoders.7.self_attn.linear_k.weight\n", - "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_k.bias\n", - "-> encoder.encoders.7.self_attn.linear_v.weight\n", - "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_v.bias\n", - "-> encoder.encoders.7.self_attn.linear_out.weight\n", - "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_out.bias\n", - "-> encoder.encoders.7.feed_forward.w_1.weight\n", - "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.7.feed_forward.w_1.bias\n", - "-> encoder.encoders.7.feed_forward.w_2.weight\n", - "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.7.feed_forward.w_2.bias\n", - "-> encoder.encoders.7.norm1.weight\n", - "-> encoder.encoders.7.norm1.bias\n", - "-> encoder.encoders.7.norm2.weight\n", - "-> encoder.encoders.7.norm2.bias\n", - "-> encoder.encoders.8.self_attn.linear_q.weight\n", - "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_q.bias\n", - "-> encoder.encoders.8.self_attn.linear_k.weight\n", - "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_k.bias\n", - "-> encoder.encoders.8.self_attn.linear_v.weight\n", - "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_v.bias\n", - "-> encoder.encoders.8.self_attn.linear_out.weight\n", - "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_out.bias\n", - "-> encoder.encoders.8.feed_forward.w_1.weight\n", - "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.8.feed_forward.w_1.bias\n", - "-> encoder.encoders.8.feed_forward.w_2.weight\n", - "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.8.feed_forward.w_2.bias\n", - "-> encoder.encoders.8.norm1.weight\n", - "-> encoder.encoders.8.norm1.bias\n", - "-> encoder.encoders.8.norm2.weight\n", - "-> encoder.encoders.8.norm2.bias\n", - "-> encoder.encoders.9.self_attn.linear_q.weight\n", - "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_q.bias\n", - "-> encoder.encoders.9.self_attn.linear_k.weight\n", - "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_k.bias\n", - "-> encoder.encoders.9.self_attn.linear_v.weight\n", - "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_v.bias\n", - "-> encoder.encoders.9.self_attn.linear_out.weight\n", - "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_out.bias\n", - "-> encoder.encoders.9.feed_forward.w_1.weight\n", - "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.9.feed_forward.w_1.bias\n", - "-> encoder.encoders.9.feed_forward.w_2.weight\n", - "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.9.feed_forward.w_2.bias\n", - "-> encoder.encoders.9.norm1.weight\n", - "-> encoder.encoders.9.norm1.bias\n", - "-> encoder.encoders.9.norm2.weight\n", - "-> encoder.encoders.9.norm2.bias\n", - "-> encoder.encoders.10.self_attn.linear_q.weight\n", - "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_q.bias\n", - "-> encoder.encoders.10.self_attn.linear_k.weight\n", - "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_k.bias\n", - "-> encoder.encoders.10.self_attn.linear_v.weight\n", - "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_v.bias\n", - "-> encoder.encoders.10.self_attn.linear_out.weight\n", - "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_out.bias\n", - "-> encoder.encoders.10.feed_forward.w_1.weight\n", - "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.10.feed_forward.w_1.bias\n", - "-> encoder.encoders.10.feed_forward.w_2.weight\n", - "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.10.feed_forward.w_2.bias\n", - "-> encoder.encoders.10.norm1.weight\n", - "-> encoder.encoders.10.norm1.bias\n", - "-> encoder.encoders.10.norm2.weight\n", - "-> encoder.encoders.10.norm2.bias\n", - "-> encoder.encoders.11.self_attn.linear_q.weight\n", - "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_q.bias\n", - "-> encoder.encoders.11.self_attn.linear_k.weight\n", - "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_k.bias\n", - "-> encoder.encoders.11.self_attn.linear_v.weight\n", - "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_v.bias\n", - "-> encoder.encoders.11.self_attn.linear_out.weight\n", - "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_out.bias\n", - "-> encoder.encoders.11.feed_forward.w_1.weight\n", - "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.11.feed_forward.w_1.bias\n", - "-> encoder.encoders.11.feed_forward.w_2.weight\n", - "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.11.feed_forward.w_2.bias\n", - "-> encoder.encoders.11.norm1.weight\n", - "-> encoder.encoders.11.norm1.bias\n", - "-> encoder.encoders.11.norm2.weight\n", - "-> encoder.encoders.11.norm2.bias\n", - "-> encoder.after_norm.weight\n", - "-> encoder.after_norm.bias\n" - ] - } - ], - "source": [ - "# dump torch model to paddle\n", - "#state_dict = model.state_dict()\n", - "paddle_state_dict = {}\n", - "\n", - "for n, p in state_dict.items():\n", - " if 'encoder' not in n:\n", - " continue \n", - " print(f'-> {n}')\n", - " \n", - " \n", - " name_change=True\n", - " if 'norm.running_mean' in n:\n", - " new_n = n.replace('norm.running_', 'norm._')\n", - " elif 'norm.running_var' in n:\n", - " new_n = n.replace('norm.running_var', 'norm._variance')\n", - " else:\n", - " name_change=False\n", - " new_n = n\n", - " if name_change:\n", - " print(f\"{n} -> {new_n}\")\n", - " \n", - " \n", - " p = p.cpu().detach().numpy()\n", - " if n.endswith('weight') and p.ndim == 2:\n", - " new_p = p.T\n", - " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", - " else:\n", - " new_p = p\n", - " \n", - " if 'global_cmvn.mean' in n:\n", - " print(p, p.dtype)\n", - " \n", - " paddle_state_dict[new_n] = new_p\n", - " \n", - "# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", - "# state=paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a1d97e9f", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.bias. encoder.encoders.6.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.weight. encoder.encoders.7.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.bias. encoder.encoders.7.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.weight. encoder.encoders.8.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.bias. encoder.encoders.8.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.weight. encoder.encoders.9.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.bias. encoder.encoders.9.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.weight. encoder.encoders.10.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.bias. encoder.encoders.10.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.weight. encoder.encoders.11.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.bias. encoder.encoders.11.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.embed.0.weight. decoder.embed.0.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.weight. decoder.after_norm.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.bias. decoder.after_norm.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.weight. decoder.output_layer.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.bias. decoder.output_layer.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.weight. decoder.decoders.0.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.bias. decoder.decoders.0.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.weight. decoder.decoders.0.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.bias. decoder.decoders.0.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.weight. decoder.decoders.0.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.bias. decoder.decoders.0.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.weight. decoder.decoders.0.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.bias. decoder.decoders.0.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.weight. decoder.decoders.0.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.bias. decoder.decoders.0.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.weight. decoder.decoders.0.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.bias. decoder.decoders.0.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.weight. decoder.decoders.0.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.bias. decoder.decoders.0.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.weight. decoder.decoders.0.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.bias. decoder.decoders.0.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.weight. decoder.decoders.0.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.bias. decoder.decoders.0.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.weight. decoder.decoders.0.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.bias. decoder.decoders.0.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.weight. decoder.decoders.0.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.bias. decoder.decoders.0.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.weight. decoder.decoders.0.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.bias. decoder.decoders.0.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.weight. decoder.decoders.0.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.bias. decoder.decoders.0.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.weight. decoder.decoders.0.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.bias. decoder.decoders.0.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.weight. decoder.decoders.0.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.bias. decoder.decoders.0.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.weight. decoder.decoders.1.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.bias. decoder.decoders.1.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.weight. decoder.decoders.1.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.bias. decoder.decoders.1.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.weight. decoder.decoders.1.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.bias. decoder.decoders.1.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.weight. decoder.decoders.1.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.bias. decoder.decoders.1.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.weight. decoder.decoders.1.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.bias. decoder.decoders.1.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.weight. decoder.decoders.1.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.bias. decoder.decoders.1.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.weight. decoder.decoders.1.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.bias. decoder.decoders.1.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.weight. decoder.decoders.1.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.bias. decoder.decoders.1.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.weight. decoder.decoders.1.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.bias. decoder.decoders.1.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.weight. decoder.decoders.1.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.bias. decoder.decoders.1.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.weight. decoder.decoders.1.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.bias. decoder.decoders.1.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.weight. decoder.decoders.1.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.bias. decoder.decoders.1.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.weight. decoder.decoders.1.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.bias. decoder.decoders.1.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.weight. decoder.decoders.1.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.bias. decoder.decoders.1.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.weight. decoder.decoders.1.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.bias. decoder.decoders.1.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.weight. decoder.decoders.2.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.bias. decoder.decoders.2.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.weight. decoder.decoders.2.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.bias. decoder.decoders.2.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.weight. decoder.decoders.2.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.bias. decoder.decoders.2.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.weight. decoder.decoders.2.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.bias. decoder.decoders.2.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.weight. decoder.decoders.2.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.bias. decoder.decoders.2.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.weight. decoder.decoders.2.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.bias. decoder.decoders.2.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.weight. decoder.decoders.2.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.bias. decoder.decoders.2.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.weight. decoder.decoders.2.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.bias. decoder.decoders.2.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.weight. decoder.decoders.2.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.bias. decoder.decoders.2.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.weight. decoder.decoders.2.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.bias. decoder.decoders.2.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.weight. decoder.decoders.2.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.bias. decoder.decoders.2.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.weight. decoder.decoders.2.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.bias. decoder.decoders.2.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.weight. decoder.decoders.2.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.bias. decoder.decoders.2.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.weight. decoder.decoders.2.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.bias. decoder.decoders.2.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.weight. decoder.decoders.2.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.bias. decoder.decoders.2.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.weight. decoder.decoders.3.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.bias. decoder.decoders.3.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.weight. decoder.decoders.3.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.bias. decoder.decoders.3.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.weight. decoder.decoders.3.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.bias. decoder.decoders.3.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.weight. decoder.decoders.3.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.bias. decoder.decoders.3.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.weight. decoder.decoders.3.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.bias. decoder.decoders.3.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.weight. decoder.decoders.3.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.bias. decoder.decoders.3.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.weight. decoder.decoders.3.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.bias. decoder.decoders.3.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.weight. decoder.decoders.3.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.bias. decoder.decoders.3.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.weight. decoder.decoders.3.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.bias. decoder.decoders.3.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.weight. decoder.decoders.3.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.bias. decoder.decoders.3.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.weight. decoder.decoders.3.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.bias. decoder.decoders.3.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.weight. decoder.decoders.3.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.bias. decoder.decoders.3.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.weight. decoder.decoders.3.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.bias. decoder.decoders.3.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.weight. decoder.decoders.3.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.bias. decoder.decoders.3.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.weight. decoder.decoders.3.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.bias. decoder.decoders.3.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.weight. decoder.decoders.4.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.bias. decoder.decoders.4.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.weight. decoder.decoders.4.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.bias. decoder.decoders.4.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.weight. decoder.decoders.4.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.bias. decoder.decoders.4.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.weight. decoder.decoders.4.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.bias. decoder.decoders.4.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.weight. decoder.decoders.4.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.bias. decoder.decoders.4.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.weight. decoder.decoders.4.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.bias. decoder.decoders.4.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.weight. decoder.decoders.4.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.bias. decoder.decoders.4.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.weight. decoder.decoders.4.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.bias. decoder.decoders.4.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.weight. decoder.decoders.4.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.bias. decoder.decoders.4.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.weight. decoder.decoders.4.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.bias. decoder.decoders.4.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.weight. decoder.decoders.4.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.bias. decoder.decoders.4.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.weight. decoder.decoders.4.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.bias. decoder.decoders.4.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.weight. decoder.decoders.4.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.bias. decoder.decoders.4.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.weight. decoder.decoders.4.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.bias. decoder.decoders.4.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.weight. decoder.decoders.4.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.bias. decoder.decoders.4.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.weight. decoder.decoders.5.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.bias. decoder.decoders.5.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.weight. decoder.decoders.5.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.bias. decoder.decoders.5.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.weight. decoder.decoders.5.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.bias. decoder.decoders.5.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.weight. decoder.decoders.5.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.bias. decoder.decoders.5.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.weight. decoder.decoders.5.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.bias. decoder.decoders.5.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.weight. decoder.decoders.5.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.bias. decoder.decoders.5.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.weight. decoder.decoders.5.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.bias. decoder.decoders.5.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.weight. decoder.decoders.5.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.bias. decoder.decoders.5.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.weight. decoder.decoders.5.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.bias. decoder.decoders.5.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.weight. decoder.decoders.5.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.bias. decoder.decoders.5.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.weight. decoder.decoders.5.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.bias. decoder.decoders.5.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.weight. decoder.decoders.5.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.bias. decoder.decoders.5.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.weight. decoder.decoders.5.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.bias. decoder.decoders.5.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.weight. decoder.decoders.5.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.bias. decoder.decoders.5.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.weight. decoder.decoders.5.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.bias. decoder.decoders.5.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n" - ] - } - ], - "source": [ - "model.set_state_dict(paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "fc7edf1e", - "metadata": {}, - "outputs": [], - "source": [ - "e_state = model.encoder.state_dict()\n", - "for key, value in e_state.items():\n", - " if 'concat_linear' in key:\n", - " continue\n", - " if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "572097d0", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "748250b7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91e5deee", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "fleet-despite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", - "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook/espnet" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "abroad-oracle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8, 57, 83)\n", - "(8, 1, 57)\n", - "[57 50 48 38 32 31 28 25]\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n", - "xs=data['xs']\n", - "masks=data['masks']\n", - "print(xs.shape)\n", - "print(masks.shape)\n", - "xs_lens = masks.sum(axis=-1).squeeze()\n", - "print(xs_lens)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "false-instrument", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[8, 13, 256]\n", - "[8, 1, 13]\n" - ] - } - ], - "source": [ - "# ecnoder\n", - "xs = paddle.to_tensor(xs, dtype='float32')\n", - "x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n", - "model.eval()\n", - "encoder_out, encoder_mask = model.encoder(xs, x_lens)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "arctic-proxy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8, 13, 256)\n", - "(8, 1, 13)\n", - "False\n", - "False\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n", - "xs = data['xs']\n", - "masks = data['masks']\n", - "print(xs.shape)\n", - "print(masks.shape)\n", - "print(np.allclose(xs, encoder_out.numpy()))\n", - "print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n", - "print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n", - "print(np.allclose(masks, encoder_mask.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "seasonal-switch", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n", - " -0.6526459 ]\n", - " [ 2.1926146 2.1373641 -0.6548196 ... -0.897318 0.6044322\n", - " -0.63332295]\n", - " [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n", - " -0.05243584]\n", - " ...\n", - " [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n", - " -0.18188554]\n", - " [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n", - " -0.72916794]\n", - " [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n", - " -0.32597935]]\n", - "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n", - " [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n", - " [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n", - " ...,\n", - " [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n", - " [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n", - " [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n" - ] - } - ], - "source": [ - "print(xs[0])\n", - "print(encoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "defined-brooks", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n", - " -0.16355833]\n", - " [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n", - " 0.11874156]\n", - " [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n", - " 0.27621832]\n", - " ...\n", - " [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n", - " -0.9794884 ]\n", - " [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n", - " -1.3076135 ]\n", - " [ 2.160009 0.98425585 -1.2231126 ... -0.03393313 1.9141548\n", - " -1.0099151 ]]\n", - "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n", - " [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n", - " [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n", - " ...,\n", - " [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n", - " [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n", - " [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n" - ] - } - ], - "source": [ - "print(xs[1])\n", - "print(encoder_out[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0504e3f8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/wenet_model.ipynb b/.notebook/wenet_model.ipynb deleted file mode 100644 index 8e10b6c4b..000000000 --- a/.notebook/wenet_model.ipynb +++ /dev/null @@ -1,5015 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cfb832c0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/wenet\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/wenet'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd /workspace/wenet/\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "62277538", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import argparse\n", - "import copy\n", - "import logging\n", - "import os\n", - "\n", - "import torch\n", - "import torch.distributed as dist\n", - "import torch.optim as optim\n", - "import yaml\n", - "from tensorboardX import SummaryWriter\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from wenet.dataset.dataset import AudioDataset, CollateFunc\n", - "from wenet.transformer.asr_model import init_asr_model\n", - "from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n", - "from wenet.utils.executor import Executor\n", - "from wenet.utils.scheduler import WarmupLR\n", - "\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "2f6ea33a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n" - ] - } - ], - "source": [ - "parser = argparse.ArgumentParser(description='training your network')\n", - "parser.add_argument('--config', default=\"examples/aishell/s0/conf/train_conformer.yaml\", help='config file')\n", - "parser.add_argument('--train_data', default=\"examples/aishell/s0/raw_wav/train/format.data\", help='train data file')\n", - "parser.add_argument('--cv_data', default=\"examples/aishell/s0/raw_wav/dev/format.data\", help='cv data file')\n", - "parser.add_argument('--gpu',\n", - " type=int,\n", - " default=-1,\n", - " help='gpu id for this local rank, -1 for cpu')\n", - "parser.add_argument('--model_dir' , help='save model dir')\n", - "parser.add_argument('--checkpoint', help='checkpoint model')\n", - "parser.add_argument('--tensorboard_dir',\n", - " default='tensorboard',\n", - " help='tensorboard log dir')\n", - "parser.add_argument('--ddp.rank',\n", - " dest='rank',\n", - " default=0,\n", - " type=int,\n", - " help='global rank for distributed training')\n", - "parser.add_argument('--ddp.world_size',\n", - " dest='world_size',\n", - " default=-1,\n", - " type=int,\n", - " help='''number of total processes/gpus for\n", - " distributed training''')\n", - "parser.add_argument('--ddp.dist_backend',\n", - " dest='dist_backend',\n", - " default='nccl',\n", - " choices=['nccl', 'gloo'],\n", - " help='distributed backend')\n", - "parser.add_argument('--ddp.init_method',\n", - " dest='init_method',\n", - " default=None,\n", - " help='ddp init method')\n", - "parser.add_argument('--num_workers',\n", - " default=0,\n", - " type=int,\n", - " help='num of subprocess workers for reading')\n", - "parser.add_argument('--pin_memory',\n", - " action='store_true',\n", - " default=False,\n", - " help='Use pinned memory buffers used for reading')\n", - "parser.add_argument('--cmvn', default=\"examples/aishell/s0/raw_wav/train/global_cmvn\", help='global cmvn file')\n", - "\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f5d6af9b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)\n" - ] - } - ], - "source": [ - "# Set random seed\n", - "torch.manual_seed(777)\n", - "print(args)\n", - "with open(args.config, 'r') as fin:\n", - " configs = yaml.load(fin, Loader=yaml.FullLoader)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "264bd353", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "7507 batches\n", - "896\n" - ] - } - ], - "source": [ - "raw_wav = configs['raw_wav']\n", - "\n", - "train_collate_func = CollateFunc(**configs['collate_conf'],\n", - " raw_wav=raw_wav)\n", - "\n", - "cv_collate_conf = copy.deepcopy(configs['collate_conf'])\n", - "# no augmenation on cv set\n", - "cv_collate_conf['spec_aug'] = False\n", - "cv_collate_conf['spec_sub'] = False\n", - "if raw_wav:\n", - " cv_collate_conf['feature_dither'] = 0.0\n", - " cv_collate_conf['speed_perturb'] = False\n", - " cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0\n", - "cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)\n", - "\n", - "dataset_conf = configs.get('dataset_conf', {})\n", - "train_dataset = AudioDataset(args.train_data,\n", - " **dataset_conf,\n", - " raw_wav=raw_wav)\n", - "cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)\n", - "# 120098 data/train/wav.scp\n", - "print(len(train_dataset), 'batches')\n", - "# 14326 data/dev/wav.scp\n", - "print(len(cv_dataset))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "88863d3c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "896\n" - ] - } - ], - "source": [ - "train_sampler = None\n", - "cv_sampler = None\n", - "train_data_loader = DataLoader(train_dataset,\n", - " collate_fn=train_collate_func,\n", - " sampler=train_sampler,\n", - " #shuffle=(train_sampler is None),\n", - " shuffle=False,\n", - " pin_memory=args.pin_memory,\n", - " batch_size=1,\n", - " num_workers=args.num_workers)\n", - "cv_data_loader = DataLoader(cv_dataset,\n", - " collate_fn=cv_collate_func,\n", - " sampler=cv_sampler,\n", - " shuffle=False,\n", - " batch_size=1,\n", - " pin_memory=args.pin_memory,\n", - " num_workers=args.num_workers)\n", - "print(len(cv_data_loader))" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "10d5acd4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4233 vocab\n", - "80 feat dim\n" - ] - } - ], - "source": [ - "if raw_wav:\n", - " input_dim = configs['collate_conf']['feature_extraction_conf'][\n", - " 'mel_bins']\n", - "else:\n", - " input_dim = train_dataset.input_dim\n", - "vocab_size = train_dataset.output_dim\n", - "print(vocab_size, 'vocab')\n", - "print(input_dim , 'feat dim')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "0380ef5a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "examples/aishell/s0/raw_wav/train/global_cmvn\n" - ] - } - ], - "source": [ - "# Save configs to model_dir/train.yaml for inference and export\n", - "configs['input_dim'] = input_dim\n", - "configs['output_dim'] = vocab_size\n", - "configs['cmvn_file'] = args.cmvn\n", - "configs['is_json_cmvn'] = raw_wav\n", - "print(args.cmvn)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "15ebf2bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(80,)\n", - "(80,)\n", - "[ 9.87176362 9.93891555 10.23818678 10.85971412 11.68652649 12.2548801\n", - " 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205\n", - " 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338\n", - " 12.09285066 11.79395822 11.62259065 11.9263303 11.8154422 11.95122567\n", - " 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142\n", - " 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304\n", - " 12.45014401 12.4966879 12.48653775 12.3550783 12.39291732 12.2553737\n", - " 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342\n", - " 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142\n", - " 13.0588804 13.05737948 12.99921175 12.93402238 12.87429219 12.71652995\n", - " 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073\n", - " 12.51480191 12.5785164 12.64719411 12.73762568 12.80017069 12.86872766\n", - " 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279\n", - " 12.87936075 11.18310185]\n", - "[0.61219383 0.49700994 0.33439025 0.31503119 0.29640823 0.28411759\n", - " 0.26972922 0.25610475 0.24632936 0.24610228 0.24733299 0.24426536\n", - " 0.23751781 0.22987273 0.22659963 0.2268427 0.23059031 0.23420722\n", - " 0.23771761 0.2411352 0.24404673 0.24557175 0.24724932 0.25055198\n", - " 0.25482755 0.2602407 0.26363878 0.26503898 0.2648467 0.26435072\n", - " 0.26353625 0.26364794 0.26411054 0.26339948 0.26212082 0.26146597\n", - " 0.26196556 0.26365859 0.26592959 0.26963884 0.27392766 0.27818809\n", - " 0.28313664 0.2863325 0.28713431 0.28649323 0.28636648 0.2867843\n", - " 0.28635904 0.28562022 0.28492711 0.28429201 0.28402977 0.28401045\n", - " 0.28560797 0.28728033 0.28969549 0.29351627 0.29826453 0.30572631\n", - " 0.31811682 0.32887739 0.33288219 0.33326245 0.33014147 0.32403202\n", - " 0.31903576 0.31316258 0.30741037 0.30370692 0.30204833 0.30049064\n", - " 0.29901079 0.29824511 0.29812308 0.29753329 0.29779342 0.30175296\n", - " 0.30955538 0.32904205]\n" - ] - } - ], - "source": [ - "import json\n", - "import math\n", - "import numpy as np\n", - "def _load_json_cmvn(json_cmvn_file):\n", - " \"\"\" Load the json format cmvn stats file and calculate cmvn\n", - "\n", - " Args:\n", - " json_cmvn_file: cmvn stats file in json format\n", - "\n", - " Returns:\n", - " a numpy array of [means, vars]\n", - " \"\"\"\n", - " with open(json_cmvn_file) as f:\n", - " cmvn_stats = json.load(f)\n", - "\n", - " means = cmvn_stats['mean_stat']\n", - " variance = cmvn_stats['var_stat']\n", - " count = cmvn_stats['frame_num']\n", - " for i in range(len(means)):\n", - " means[i] /= count\n", - " variance[i] = variance[i] / count - means[i] * means[i]\n", - " if variance[i] < 1.0e-20:\n", - " variance[i] = 1.0e-20\n", - " variance[i] = 1.0 / math.sqrt(variance[i])\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def _load_kaldi_cmvn(kaldi_cmvn_file):\n", - " \"\"\" Load the kaldi format cmvn stats file and calculate cmvn\n", - "\n", - " Args:\n", - " kaldi_cmvn_file: kaldi text style global cmvn file, which\n", - " is generated by:\n", - " compute-cmvn-stats --binary=false scp:feats.scp global_cmvn\n", - "\n", - " Returns:\n", - " a numpy array of [means, vars]\n", - " \"\"\"\n", - " means = []\n", - " variance = []\n", - " with open(kaldi_cmvn_file, 'r') as fid:\n", - " # kaldi binary file start with '\\0B'\n", - " if fid.read(2) == '\\0B':\n", - " logger.error('kaldi cmvn binary file is not supported, please '\n", - " 'recompute it by: compute-cmvn-stats --binary=false '\n", - " ' scp:feats.scp global_cmvn')\n", - " sys.exit(1)\n", - " fid.seek(0)\n", - " arr = fid.read().split()\n", - " assert (arr[0] == '[')\n", - " assert (arr[-2] == '0')\n", - " assert (arr[-1] == ']')\n", - " feat_dim = int((len(arr) - 2 - 2) / 2)\n", - " for i in range(1, feat_dim + 1):\n", - " means.append(float(arr[i]))\n", - " count = float(arr[feat_dim + 1])\n", - " for i in range(feat_dim + 2, 2 * feat_dim + 2):\n", - " variance.append(float(arr[i]))\n", - "\n", - " for i in range(len(means)):\n", - " means[i] /= count\n", - " variance[i] = variance[i] / count - means[i] * means[i]\n", - " if variance[i] < 1.0e-20:\n", - " variance[i] = 1.0e-20\n", - " variance[i] = 1.0 / math.sqrt(variance[i])\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):\n", - " npzfile = np.load(npz_cmvn_file)\n", - " means = npzfile[\"mean\"] #(1, D)\n", - " std = npzfile[\"std\"] #(1, D)\n", - " std = np.clip(std, eps, None)\n", - " variance = 1.0 / std\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def load_cmvn(cmvn_file: str, filetype: str):\n", - " \"\"\"load cmvn from file.\n", - "\n", - " Args:\n", - " cmvn_file (str): cmvn path.\n", - " filetype (str): file type, optional[npz, json, kaldi].\n", - "\n", - " Raises:\n", - " ValueError: file type not support.\n", - "\n", - " Returns:\n", - " Tuple[np.ndarray, np.ndarray]: mean, istd\n", - " \"\"\"\n", - " assert filetype in ['npz', 'json', 'kaldi'], filetype\n", - " filetype = filetype.lower()\n", - " if filetype == \"json\":\n", - " cmvn = _load_json_cmvn(cmvn_file)\n", - " elif filetype == \"kaldi\":\n", - " cmvn = _load_kaldi_cmvn(cmvn_file)\n", - " elif filetype == \"npz\":\n", - " cmvn = _load_npz_cmvn(cmvn_file)\n", - " else:\n", - " raise ValueError(f\"cmvn file type no support: {filetype}\")\n", - " return cmvn[0], cmvn[1]\n", - "\n", - "mean, istd = load_cmvn(args.cmvn, 'json')\n", - "print(mean.shape)\n", - "print(istd.shape)\n", - "print(mean)\n", - "print(istd)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3cfa5e23", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ASRModel(\n", - " (encoder): ConformerEncoder(\n", - " (global_cmvn): GlobalCMVN()\n", - " (embed): Conv2dSubsampling4(\n", - " (conv): Sequential(\n", - " (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " (1): ReLU()\n", - " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " (3): ReLU()\n", - " )\n", - " (out): Sequential(\n", - " (0): Linear(in_features=4864, out_features=256, bias=True)\n", - " )\n", - " (pos_enc): RelPositionalEncoding(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (encoders): ModuleList(\n", - " (0): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (1): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (2): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (3): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (4): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (5): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (6): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (7): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (8): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (9): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (10): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (11): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " )\n", - " )\n", - " (decoder): TransformerDecoder(\n", - " (embed): Sequential(\n", - " (0): Embedding(4233, 256)\n", - " (1): PositionalEncoding(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n", - " (decoders): ModuleList(\n", - " (0): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " )\n", - " )\n", - " (ctc): CTC(\n", - " (ctc_lo): Linear(in_features=256, out_features=4233, bias=True)\n", - " (ctc_loss): CTCLoss()\n", - " )\n", - " (criterion_att): LabelSmoothingLoss(\n", - " (criterion): KLDivLoss()\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "# Init asr model from configs\n", - "model = init_asr_model(configs)\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "3c780af5", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def summary(layer, print_func=print):\n", - " num_params = num_elements = 0\n", - " for name, param in layer.state_dict().items():\n", - " if print_func:\n", - " print_func(\n", - " \"{} | {} | {}\".format(name, param.shape, np.prod(param.shape)))\n", - " num_elements += np.prod(param.shape)\n", - " num_params += 1\n", - " if print_func:\n", - " print_func(\n", - " f\"Total parameters: {num_params}, {num_elements} elements.\"\n", - " )\n", - " \n", - "def print_params(model, print_func=print):\n", - " if print_func is None:\n", - " return\n", - " total = 0.0\n", - " num_params = 0.0\n", - " for n, p in model.named_parameters():\n", - " msg = f\"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}\"\n", - " total += np.prod(p.shape)\n", - " num_params += 1\n", - " if print_func:\n", - " print_func(msg)\n", - " if print_func:\n", - " print_func(f\"Total parameters: {num_params}, {total} elements.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e159a200", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.global_cmvn.mean | torch.Size([80]) | 80\n", - "encoder.global_cmvn.istd | torch.Size([80]) | 80\n", - "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304\n", - "encoder.embed.conv.0.bias | torch.Size([256]) | 256\n", - "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824\n", - "encoder.embed.conv.2.bias | torch.Size([256]) | 256\n", - "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n", - "encoder.embed.out.0.bias | torch.Size([256]) | 256\n", - "encoder.after_norm.weight | torch.Size([256]) | 256\n", - "encoder.after_norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n", - "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n", - "decoder.after_norm.weight | torch.Size([256]) | 256\n", - "decoder.after_norm.bias | torch.Size([256]) | 256\n", - "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n", - "decoder.output_layer.bias | torch.Size([4233]) | 4233\n", - "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256\n", - "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n", - "ctc.ctc_lo.bias | torch.Size([4233]) | 4233\n", - "Total parameters: 701, 49355454.0 elements.\n" - ] - } - ], - "source": [ - "summary(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8494c6ab", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0648a969", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n", - "encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n", - "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n", - "encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n", - "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n", - "encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n", - "encoder.after_norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.after_norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "decoder.after_norm.weight | torch.Size([256]) | 256 | True\n", - "decoder.after_norm.bias | torch.Size([256]) | 256 | True\n", - "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "ctc.ctc_lo.bias | torch.Size([4233]) | 4233 | True\n", - "Total parameters: 663.0, 49349138.0 elements.\n" - ] - } - ], - "source": [ - "print_params(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "5ad6de2a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", - "torch.Size([16, 207, 80])\n", - "tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", - " [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", - " [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", - " ...,\n", - " [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", - " [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", - " [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", - "\n", - " [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", - " [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", - " [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", - " ...,\n", - " [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", - " [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", - " [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", - "\n", - " [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", - " [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", - " [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", - " ...,\n", - " [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " ...,\n", - "\n", - " [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", - " [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", - " [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", - " [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", - " [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", - " [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", - " [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", - "tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", - " 166, 163], dtype=torch.int32)\n", - "tensor([[2995, 3116, 1209, 565, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077],\n", - " [2693, 524, 234, 1145, 366, -1],\n", - " [3875, 4211, 3062, 700, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710],\n", - " [ 25, 1149, 3930, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110],\n", - " [3703, 2, 565, 3827, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, -1],\n", - " [ 426, 811, 95, 489, 144, -1],\n", - " [2313, 2006, 489, 975, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347],\n", - " [ 70, 1741, 702, 1666, -1, -1],\n", - " [ 703, 1778, 1030, 849, -1, -1],\n", - " [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", - "tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n" - ] - } - ], - "source": [ - "for batch in cv_data_loader:\n", - " keys, feat, text, feat_len, text_len = batch\n", - " print(keys)\n", - " print(feat.shape)\n", - " print(feat)\n", - " print(feat_len)\n", - " print(text)\n", - " print(text_len)\n", - " np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "852a9c95", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CODE_OF_CONDUCT.md data.npz install.sh README.md\t tools\r\n", - "CONTRIBUTING.md docs LICENSE\t requirements.txt venv\r\n", - "CPPLINT.cfg\t examples Makefile\t runtime\t wenet\r\n" - ] - } - ], - "source": [ - "!ls\n", - "!cp data.npz /workspace/DeepSpeech-2.x/.notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "cde24c4e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(111.9988)\n", - "tensor(830.9634, grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True])\n", - "tensor(669.4633, grad_fn=)\n", - "tensor(142.4888, grad_fn=) tensor(41.8415, grad_fn=) tensor(377.3326, grad_fn=)\n" - ] - } - ], - "source": [ - "model.cpu().eval()\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss, attention_loss, ctc_loss )" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "be5b2a2c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], - "source": [ - "print(total_loss.device)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "5b791771", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "cuda:0\n", - "142.4888 41.84146 377.33258\n" - ] - } - ], - "source": [ - "model.cuda().eval()\n", - "feat=feat.cuda()\n", - "feat_len=feat_len.cuda()\n", - "text=text.cuda()\n", - "text_len=text_len.cuda()\n", - "\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss.device)\n", - "print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "1baef537", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 51, 256])\n", - "torch.Size([16, 1, 51])\n", - "tensor([[-0.7019, 0.5625, 0.6880, ..., 1.1237, 0.7804, 1.1369],\n", - " [-0.7788, 0.3913, 0.7189, ..., 1.2519, 0.8862, 1.3173],\n", - " [-0.9591, 0.6346, 0.8767, ..., 0.9818, 0.7440, 1.2903],\n", - " ...,\n", - " [-1.0732, 0.6724, 0.9230, ..., 0.9075, 0.8177, 1.3240],\n", - " [-1.1654, 0.6820, 0.6939, ..., 1.2238, 0.8028, 1.4507],\n", - " [-1.2732, 0.7146, 0.7582, ..., 0.9415, 0.8775, 1.2623]],\n", - " device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)\n", - "print(encoder_out[0])\n", - "\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n", - " mask=encoder_mask.cpu().detach().numpy(), \n", - " out=encoder_out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e22c782", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "30b6b946", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 9.871763 9.938915 10.238187 10.8597145 11.686526 12.25488\n", - " 12.657681 12.86139 12.807339 12.566256 12.32007 12.138792\n", - " 12.313189 12.552552 12.612239 12.569745 12.389728 12.143833\n", - " 12.092851 11.793959 11.622591 11.926331 11.815442 11.951225\n", - " 11.831805 11.887888 11.790144 11.88072 11.900057 11.973481\n", - " 12.009822 12.008814 12.026197 12.104796 12.21555 12.343993\n", - " 12.450144 12.496688 12.486538 12.355079 12.392918 12.255374\n", - " 12.264963 12.253142 12.325458 12.4335985 12.548675 12.676334\n", - " 12.809207 12.929347 12.961151 12.968834 12.995931 13.047281\n", - " 13.058881 13.05738 12.999211 12.934022 12.874292 12.71653\n", - " 12.48942 12.274784 12.261631 12.286319 12.31956 12.422907\n", - " 12.514802 12.578516 12.647194 12.737626 12.800171 12.868728\n", - " 12.966668 13.064786 13.159159 13.272843 13.310819 13.239043\n", - " 12.879361 11.183102 ] float32\n", - "encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n", - "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n", - "encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n", - "encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n", - "encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n", - "encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n", - "encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n", - "encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n", - "encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n", - "encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n", - "encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n", - "encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n", - "encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n", - "encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n", - "encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n", - "encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n", - "encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n", - "encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n", - "encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n", - "encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n", - "encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n", - "encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n", - "encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n", - "encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n", - "encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n", - "encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n", - "decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n", - "decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.1.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.4.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.5.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.5.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "ctc.ctc_lo.weight: (4233, 256) -> (256, 4233)\n" - ] - } - ], - "source": [ - "# dump torch model to paddle\n", - "import numpy as np\n", - "state_dict = model.state_dict()\n", - "paddle_state_dict = {}\n", - "\n", - "for n, p in state_dict.items():\n", - " name_change=True\n", - "\n", - " if 'norm.running_mean' in n:\n", - " new_n = n.replace('norm.running_', 'norm._')\n", - " elif 'norm.running_var' in n:\n", - " new_n = n.replace('norm.running_var', 'norm._variance')\n", - " else:\n", - " name_change=False\n", - " new_n = n\n", - " \n", - " if name_change:\n", - " print(f\"{n} -> {new_n}\")\n", - " \n", - " p = p.cpu().detach().numpy()\n", - " if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:\n", - " new_p = p.T\n", - " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", - " else:\n", - " new_p = p\n", - " \n", - " if 'global_cmvn.mean' in n:\n", - " print(p, p.dtype)\n", - " \n", - " paddle_state_dict[new_n] = new_p\n", - " \n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", - " state=paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7307dc5b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "d99b29bc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(377.3326, device='cuda:0', grad_fn=)\n", - "None\n", - "[[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", - " -5.93366381e-03 -7.26613170e-03]\n", - " [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", - " 2.46338220e-03 2.31891591e-03]\n", - " [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", - " 5.76929189e-03 7.48792710e-03]\n", - " ...\n", - " [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", - " 1.16123557e-02 1.44716976e-02]\n", - " [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", - " 8.58021621e-03 1.07796099e-02]\n", - " [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", - " 1.60815325e-02 2.03892551e-02]]\n", - "[-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", - " 1.0920014e-02 1.3787906e-02]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n", - " print(loss_ctc.grad)\n" - ] - } - ], - "source": [ - "encoder_out_lens = encoder_mask.squeeze(1).sum(1)\n", - "loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)\n", - "print(loss_ctc)\n", - "dir(loss_ctc)\n", - "loss_ctc.backward()\n", - "print(loss_ctc.grad)\n", - "#print(model.ctc.ctc_lo.weight.grad)\n", - "print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())\n", - "print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "49b05d6d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=) 0.0\n" - ] - } - ], - "source": [ - "loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,\n", - " text, text_len)\n", - "print(loss_att, acc_att)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "413b413f", - "metadata": {}, - "outputs": [], - "source": [ - "def pad_list(xs, pad_value: int):\n", - " n_batch = len(xs)\n", - " max_len = max([x.size(0) for x in xs])\n", - " pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)\n", - " pad = pad.fill_(pad_value)\n", - " for i in range(n_batch):\n", - " pad[i, :xs[i].size(0)] = xs[i]\n", - "\n", - " return pad\n", - "\n", - "def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,\n", - " ignore_id: int):\n", - "\n", - " _sos = torch.tensor([sos],\n", - " dtype=torch.long,\n", - " requires_grad=False,\n", - " device=ys_pad.device)\n", - " _eos = torch.tensor([eos],\n", - " dtype=torch.long,\n", - " requires_grad=False,\n", - " device=ys_pad.device)\n", - " ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", - " ys_in = [torch.cat([_sos, y], dim=0) for y in ys]\n", - " ys_out = [torch.cat([y, _eos], dim=0) for y in ys]\n", - " return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "ff0c2400", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[4232, 2995, 3116, 1209, 565, 4232, 4232],\n", - " [4232, 236, 1176, 331, 66, 3925, 4077],\n", - " [4232, 2693, 524, 234, 1145, 366, 4232],\n", - " [4232, 3875, 4211, 3062, 700, 4232, 4232],\n", - " [4232, 272, 987, 1134, 494, 2959, 4232],\n", - " [4232, 1936, 3715, 120, 2553, 2695, 2710],\n", - " [4232, 25, 1149, 3930, 4232, 4232, 4232],\n", - " [4232, 1753, 1778, 1237, 482, 3925, 110],\n", - " [4232, 3703, 2, 565, 3827, 4232, 4232],\n", - " [4232, 1150, 2734, 10, 2478, 3490, 4232],\n", - " [4232, 426, 811, 95, 489, 144, 4232],\n", - " [4232, 2313, 2006, 489, 975, 4232, 4232],\n", - " [4232, 3702, 3414, 205, 1488, 2966, 1347],\n", - " [4232, 70, 1741, 702, 1666, 4232, 4232],\n", - " [4232, 703, 1778, 1030, 849, 4232, 4232],\n", - " [4232, 814, 1674, 115, 3827, 4232, 4232]], device='cuda:0')\n", - "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1],\n", - " [3875, 4211, 3062, 700, 4232, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2, 565, 3827, 4232, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", - " [ 426, 811, 95, 489, 144, 4232, -1],\n", - " [2313, 2006, 489, 975, 4232, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", - " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", - " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n" - ] - } - ], - "source": [ - "ys_pad = text\n", - "ys_pad_lens = text_len\n", - "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", - " model.ignore_id)\n", - "ys_in_lens = ys_pad_lens + 1\n", - "print(ys_in_pad)\n", - "print(ys_out_pad)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "3e84da38", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 7, 4233])\n", - "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", - " 1.5035e-02, 4.0337e-01],\n", - " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", - " -1.4353e-01, -1.0024e+00],\n", - " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", - " -1.1672e+00, -2.6849e-01],\n", - " ...,\n", - " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", - " 4.5470e-02, -3.7139e-01],\n", - " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", - " -7.9347e-04, 4.2538e-01],\n", - " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", - " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "print(decoder_out.shape)\n", - "print(decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aac441ea", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "5ddbca73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.float32\n", - "torch.int64\n", - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=)\n", - "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1],\n", - " [3875, 4211, 3062, 700, 4232, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2, 565, 3827, 4232, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", - " [ 426, 811, 95, 489, 144, 4232, -1],\n", - " [2313, 2006, 489, 975, 4232, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", - " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", - " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n", - "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", - " 1.5035e-02, 4.0337e-01],\n", - " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", - " -1.4353e-01, -1.0024e+00],\n", - " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", - " -1.1672e+00, -2.6849e-01],\n", - " ...,\n", - " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", - " 4.5470e-02, -3.7139e-01],\n", - " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", - " -7.9347e-04, 4.2538e-01],\n", - " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", - " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "print(decoder_out.dtype)\n", - "print(ys_out_pad.dtype)\n", - "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "print(ys_out_pad)\n", - "print(decoder_out[0])\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',\n", - " decoder_out=decoder_out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78f98c0b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "8d968cd3", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch import nn\n", - "\n", - "\n", - "class LabelSmoothingLoss(nn.Module):\n", - " def __init__(self,\n", - " size: int,\n", - " padding_idx: int,\n", - " smoothing: float,\n", - " normalize_length: bool = False):\n", - " \"\"\"Construct an LabelSmoothingLoss object.\"\"\"\n", - " super(LabelSmoothingLoss, self).__init__()\n", - " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", - " self.padding_idx = padding_idx\n", - " self.confidence = 1.0 - smoothing\n", - " self.smoothing = smoothing\n", - " self.size = size\n", - " self.normalize_length = normalize_length\n", - "\n", - " def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"Compute loss between x and target.\n", - "\n", - " The model outputs and data labels tensors are flatten to\n", - " (batch*seqlen, class) shape and a mask is applied to the\n", - " padding part which should not be calculated for loss.\n", - "\n", - " Args:\n", - " x (torch.Tensor): prediction (batch, seqlen, class)\n", - " target (torch.Tensor):\n", - " target signal masked with self.padding_id (batch, seqlen)\n", - " Returns:\n", - " loss (torch.Tensor) : The KL loss, scalar float value\n", - " \"\"\"\n", - " assert x.size(2) == self.size\n", - " batch_size = x.size(0)\n", - " x = x.view(-1, self.size)\n", - " target = target.view(-1)\n", - " # use zeros_like instead of torch.no_grad() for true_dist,\n", - " # since no_grad() can not be exported by JIT\n", - " true_dist = torch.zeros_like(x)\n", - " true_dist.fill_(self.smoothing / (self.size - 1))\n", - " ignore = target == self.padding_idx # (B,)\n", - " print(self.smoothing / (self.size - 1))\n", - " print(true_dist)\n", - " total = len(target) - ignore.sum().item()\n", - " target = target.masked_fill(ignore, 0) # avoid -1 index\n", - " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n", - " print(true_dist.dtype)\n", - " print(true_dist.square().sum())\n", - " kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)\n", - " print(kl.sum())\n", - " denom = total if self.normalize_length else batch_size\n", - " print(ignore)\n", - " numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", - " print(numer)\n", - " return numer /denom" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "3df340ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.3629489603024576e-05\n", - "tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " ...,\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05]], device='cuda:0')\n", - "torch.float32\n", - "tensor(90.7203, device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=)\n", - "torch.int64\n" - ] - } - ], - "source": [ - "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", - "loss_att = criteron(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "print(ys_out_pad.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "badc410d", - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecoder_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m allow_unreachable=True) # allow_unreachable flag\n", - "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time." - ] - } - ], - "source": [ - "loss_att.backward()\n", - "print(loss_att.grad)\n", - "print(decoder_out.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "219eb41f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([ 0.0024, 0.0019, -0.1098, ..., 0.0028, 0.0020, -1.7978],\n", - " device='cuda:0')\n", - "tensor([[ 6.5052e-04, 6.4419e-05, -6.1955e-06, ..., 9.8220e-04,\n", - " -2.5918e-05, 3.3754e-04],\n", - " [ 3.9305e-04, 4.5799e-04, 1.4362e-04, ..., 4.6800e-04,\n", - " 1.6911e-04, 2.7067e-04],\n", - " [-1.3593e-01, 5.2201e-02, 3.2895e-02, ..., 2.4580e-02,\n", - " 1.4590e-01, -4.6850e-02],\n", - " ...,\n", - " [ 1.0434e-03, 4.2251e-04, 6.5688e-04, ..., 1.2144e-03,\n", - " 2.1159e-04, 6.6838e-04],\n", - " [ 6.4997e-04, 4.4301e-04, 4.1550e-04, ..., 1.0420e-03,\n", - " 2.4114e-04, 1.5338e-04],\n", - " [-9.9337e-01, 5.4573e-01, -1.1371e-02, ..., -4.3175e-01,\n", - " -2.7850e-01, -4.4679e-01]], device='cuda:0')\n" - ] - } - ], - "source": [ - "print(model.decoder.output_layer.bias.grad)\n", - "print(model.decoder.output_layer.weight.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "40d00a54", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01, ..., -8.2428e-01,\n", - " -1.0265e+00, -9.6301e-01],\n", - " [-4.4642e-02, 2.3176e-01, -3.2539e-01, ..., -9.0159e-01,\n", - " -1.0325e+00, -7.5987e-01],\n", - " [ 5.0035e-01, 2.2691e-01, -7.3052e-01, ..., -1.0055e+00,\n", - " -8.7123e-01, -1.0306e+00],\n", - " ...,\n", - " [-4.0024e-01, -1.4325e-01, -5.7947e-01, ..., -1.0718e+00,\n", - " -1.2806e+00, -1.0518e+00],\n", - " [ 1.5755e-01, -1.8495e-03, -2.8703e-01, ..., -1.1090e+00,\n", - " -9.4519e-01, -7.2506e-01],\n", - " [-4.7520e-01, -1.3942e+00, -2.5754e-01, ..., -1.1365e+00,\n", - " -1.1943e+00, -1.2290e+00]],\n", - "\n", - " [[ 9.5454e-01, 3.6428e-01, -1.3891e+00, ..., -1.1637e+00,\n", - " -1.2845e+00, -1.2015e+00],\n", - " [-8.5735e-02, -1.0579e+00, -8.9173e-01, ..., -9.6441e-01,\n", - " -1.1255e+00, -1.2599e+00],\n", - " [ 4.7654e-01, 3.2887e-01, -5.9201e-01, ..., -1.1942e+00,\n", - " -1.1430e+00, -1.0242e+00],\n", - " ...,\n", - " [-4.7431e-01, -3.3559e-01, -7.2326e-01, ..., -1.4506e+00,\n", - " -1.3957e+00, -1.0464e+00],\n", - " [ 3.6113e-01, 1.0381e-01, -1.1599e+00, ..., -1.0439e+00,\n", - " -1.0221e+00, -1.0208e+00],\n", - " [-1.2717e+00, -2.1460e+00, -7.5677e-01, ..., -9.7822e-01,\n", - " -9.3785e-01, -1.0371e+00]],\n", - "\n", - " [[-1.5465e+00, -1.0152e+00, -8.8901e-01, ..., -4.8522e-01,\n", - " -7.5163e-01, -6.7765e-01],\n", - " [-7.6101e-01, -7.3352e-01, -9.1588e-01, ..., -2.4836e-01,\n", - " -5.8927e-01, -7.3723e-01],\n", - " [-2.4714e-02, 1.7016e-01, -4.2326e-01, ..., -3.3204e-01,\n", - " -7.6696e-01, -7.1652e-01],\n", - " ...,\n", - " [-1.7032e+00, -1.2591e+00, -1.1449e+00, ..., -1.1810e+00,\n", - " -1.1163e+00, -9.3108e-01],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.4983e-01, 2.6117e-01, -8.4197e-01, ..., -8.7213e-01,\n", - " -1.1073e+00, -1.3253e+00],\n", - " [ 3.5391e-01, -1.5846e-02, -4.0425e-01, ..., -9.9173e-01,\n", - " -1.0727e+00, -1.1924e+00],\n", - " [ 3.7704e-01, -6.2785e-02, -1.1468e-01, ..., -1.1021e+00,\n", - " -1.0952e+00, -1.1182e+00],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " [[ 4.4458e-02, -1.7547e-01, -6.7475e-01, ..., -4.9801e-01,\n", - " -5.6783e-01, -7.7852e-01],\n", - " [-1.3428e+00, -8.0343e-01, -9.0457e-01, ..., -6.5902e-01,\n", - " -7.2550e-01, -6.2796e-01],\n", - " [-7.6253e-01, -1.3071e-01, -1.3280e-01, ..., -5.6133e-01,\n", - " -6.0588e-01, -7.2115e-01],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " [[-1.0798e+00, -1.0834e+00, -1.1797e+00, ..., -1.7757e-01,\n", - " -4.3747e-01, -4.0007e-02],\n", - " [ 9.2354e-01, 6.3771e-01, -5.2810e-01, ..., -1.2928e-01,\n", - " -2.0342e-01, 1.6656e-01],\n", - " [ 4.9337e-01, -9.1133e-03, -7.3302e-01, ..., 1.0074e-01,\n", - " -9.8115e-02, -9.2357e-03],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "print(xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "505ca294", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[ True, True, True, ..., True, True, True]],\n", - "\n", - " [[ True, True, True, ..., True, True, True]],\n", - "\n", - " [[ True, True, True, ..., True, False, False]],\n", - "\n", - " ...,\n", - "\n", - " [[ True, True, True, ..., False, False, False]],\n", - "\n", - " [[ True, True, True, ..., False, False, False]],\n", - "\n", - " [[ True, True, True, ..., False, False, False]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "from wenet.utils.mask import make_pad_mask\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "print(masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "aa03c2b9", - "metadata": {}, - "outputs": [], - "source": [ - "xs, pos_emb, masks = model.encoder.embed(xs, masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "ebc0ea12", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-0.5482, 2.2866, -1.0750, ..., 1.4504, 0.2895, -0.6945],\n", - " [-0.8013, 1.7688, -1.6639, ..., 1.8332, 0.6791, -0.2000],\n", - " [-1.7112, 2.7057, -1.3363, ..., 1.2336, 0.1870, -0.5735],\n", - " ...,\n", - " [-0.9697, 2.3129, -0.8752, ..., 0.8584, 0.4853, -0.4177],\n", - " [-1.3609, 2.1779, -1.7813, ..., 2.0928, 0.2528, -0.3650],\n", - " [-1.6967, 2.3544, -1.7417, ..., 1.3670, 0.5951, -0.7415]],\n", - "\n", - " [[-1.9828, 2.3178, -0.9079, ..., 0.4117, 0.5006, 0.0872],\n", - " [-0.7640, 1.3558, -1.3613, ..., 0.7317, 0.6784, 0.1685],\n", - " [-0.9504, 1.6038, -1.3030, ..., 0.5754, 0.2677, 0.3343],\n", - " ...,\n", - " [-1.4757, 2.5317, -1.2321, ..., 1.2997, 0.5019, -0.1034],\n", - " [-1.1731, 2.3172, -1.2542, ..., 1.7391, 0.2171, -0.4445],\n", - " [-1.2700, 3.2229, -0.8872, ..., 1.6461, 0.0973, -0.7679]],\n", - "\n", - " [[-0.5873, 1.4291, -1.3950, ..., 0.2102, 0.1027, 0.0918],\n", - " [ 0.1743, 1.7834, -1.6422, ..., 0.8113, 0.3137, 0.5634],\n", - " [-0.3492, 1.8310, -1.0685, ..., 0.6924, 0.1378, 0.4594],\n", - " ...,\n", - " [-1.0869, 2.3002, -1.2638, ..., 1.7998, 0.5134, -0.5223],\n", - " [-1.2614, 2.7240, -1.3734, ..., 1.4445, 0.5742, -0.3320],\n", - " [-2.2068, 4.3462, -3.8289, ..., 2.1426, 1.2034, -1.3795]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.3914, 1.8553, -0.5747, ..., 1.0062, 0.4632, -1.0452],\n", - " [-0.8605, 2.0172, -1.4437, ..., 1.4526, 0.1657, 0.5923],\n", - " [-0.7307, 2.2841, -1.0699, ..., 1.5825, -0.0980, 0.5503],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", - "\n", - " [[-0.1619, 0.6255, -1.1323, ..., 0.0724, -0.2204, 0.4636],\n", - " [-0.0831, 0.5750, -1.0930, ..., 0.9110, -0.0650, 0.7299],\n", - " [-0.2820, 0.0801, -0.9418, ..., 0.3379, -0.1166, 0.4451],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", - "\n", - " [[-0.5458, -0.6909, -1.3597, ..., -0.7818, 0.6875, 0.9843],\n", - " [ 0.0421, -1.1062, -1.4389, ..., -0.0239, 0.9115, 0.5287],\n", - " [-0.2909, -0.1886, -1.5487, ..., -0.1392, 0.0580, 0.3066],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", - " 0.0000e+00, 1.0000e+00],\n", - " [ 8.4147e-01, 5.4030e-01, 8.0196e-01, ..., 1.0000e+00,\n", - " 1.0746e-04, 1.0000e+00],\n", - " [ 9.0930e-01, -4.1615e-01, 9.5814e-01, ..., 1.0000e+00,\n", - " 2.1492e-04, 1.0000e+00],\n", - " ...,\n", - " [-7.6825e-01, -6.4014e-01, 6.3280e-01, ..., 9.9998e-01,\n", - " 5.1581e-03, 9.9999e-01],\n", - " [-9.5375e-01, 3.0059e-01, 9.9899e-01, ..., 9.9998e-01,\n", - " 5.2656e-03, 9.9999e-01],\n", - " [-2.6237e-01, 9.6497e-01, 5.6075e-01, ..., 9.9998e-01,\n", - " 5.3730e-03, 9.9999e-01]]], device='cuda:0')\n", - "tensor([[[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, False, False, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, False, False, False, False, False, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, False, False, False, False, False, False, False, False, False,\n", - " False]]], device='cuda:0')\n", - "torch.Size([16, 1, 51])\n" - ] - } - ], - "source": [ - "print(xs)\n", - "print(pos_emb)\n", - "print(masks)\n", - "print(masks.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "4289461b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", - " -0.6945408 ]\n", - " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", - " -0.1999542 ]\n", - " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", - " -0.5735198 ]\n", - " ...\n", - " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", - " -0.41773027]\n", - " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", - " -0.36496443]\n", - " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", - " -0.74147725]]\n", - "\n", - " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", - " 0.08721463]\n", - " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", - " 0.16851945]\n", - " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", - " 0.33433008]\n", - " ...\n", - " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", - " -0.10343577]\n", - " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", - " -0.44447583]\n", - " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", - " -0.7678688 ]]\n", - "\n", - " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", - " 0.09179455]\n", - " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", - " 0.56344515]\n", - " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", - " 0.45937473]\n", - " ...\n", - " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", - " -0.52227837]\n", - " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", - " -0.33201432]\n", - " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", - " -1.3795122 ]]\n", - "\n", - " ...\n", - "\n", - " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", - " -1.045236 ]\n", - " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", - " 0.5923172 ]\n", - " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", - " 0.55030036]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", - " 0.46362036]\n", - " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", - " 0.72986233]\n", - " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", - " 0.44514441]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", - " 0.9842716 ]\n", - " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", - " 0.52870303]\n", - " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", - " 0.30663735]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]]\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", - "print(xs.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "67e10d73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 2.0908e-03],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 1.1943e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 4.6105e-02, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 9.6723e-03,\n", - " 4.6135e-02, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[2.2816e-01, 2.4615e-01, 2.5304e-01, ..., 2.0402e-01,\n", - " 2.3248e-01, 3.1191e-01],\n", - " [1.3587e-01, 2.8877e-01, 2.7991e-01, ..., 1.9210e-01,\n", - " 2.0346e-01, 1.9934e-01],\n", - " [2.5739e-01, 3.9348e-01, 2.7877e-01, ..., 2.7483e-01,\n", - " 1.9302e-01, 2.3810e-01],\n", - " ...,\n", - " [1.1939e-01, 2.8473e-01, 3.3082e-01, ..., 2.3838e-01,\n", - " 2.2104e-01, 2.3906e-01],\n", - " [1.7388e-01, 2.0402e-01, 4.0263e-01, ..., 2.4782e-01,\n", - " 2.6742e-01, 1.5427e-01],\n", - " [0.0000e+00, 2.9081e-01, 2.7726e-01, ..., 1.7540e-01,\n", - " 1.8479e-01, 2.2483e-01]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.5447e-01, 3.8861e-01, 3.9724e-01, ..., 3.8680e-01,\n", - " 3.3568e-01, 3.4552e-01],\n", - " [4.1739e-01, 5.1039e-01, 4.1730e-01, ..., 3.3993e-01,\n", - " 3.7082e-01, 3.5110e-01],\n", - " [3.6117e-01, 4.0745e-01, 4.8491e-01, ..., 3.4849e-01,\n", - " 3.2321e-01, 3.5189e-01],\n", - " ...,\n", - " [2.3144e-01, 3.8021e-01, 5.1526e-01, ..., 3.6499e-01,\n", - " 3.7412e-01, 3.9986e-01],\n", - " [3.4679e-01, 4.0238e-01, 5.0077e-01, ..., 3.6185e-01,\n", - " 3.1597e-01, 3.6335e-01],\n", - " [3.6498e-01, 3.7943e-01, 5.1719e-01, ..., 3.1798e-01,\n", - " 3.3657e-01, 3.4130e-01]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.4560e-02, 9.4475e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5002e-02, 2.9632e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.2952e-02, 0.0000e+00, 0.0000e+00, ..., 4.5850e-02,\n", - " 2.0439e-02, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 4.4258e-02],\n", - " [0.0000e+00, 0.0000e+00, 2.5565e-02, ..., 0.0000e+00,\n", - " 9.0044e-03, 4.9084e-02]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1141e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.3697e-01, 3.8527e-01, 3.2900e-01, ..., 2.8704e-01,\n", - " 2.3351e-01, 1.9004e-01],\n", - " [1.3575e-01, 3.5783e-01, 3.3573e-01, ..., 2.2082e-01,\n", - " 1.5855e-01, 1.3587e-01],\n", - " [2.1929e-01, 2.8900e-01, 2.8255e-01, ..., 2.0603e-01,\n", - " 2.3927e-01, 2.1909e-01],\n", - " ...,\n", - " [2.3292e-01, 3.9097e-01, 3.6399e-01, ..., 2.0598e-01,\n", - " 2.5374e-01, 2.3137e-01],\n", - " [1.8739e-01, 3.0794e-01, 3.0297e-01, ..., 2.7251e-01,\n", - " 2.5192e-01, 2.0837e-01],\n", - " [2.2454e-01, 4.1402e-01, 5.4083e-01, ..., 3.1875e-01,\n", - " 2.5080e-01, 2.5939e-01]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.6457e-01, 4.9519e-01, 5.6702e-01, ..., 3.0955e-01,\n", - " 3.5292e-01, 3.2669e-01],\n", - " [2.1577e-01, 5.1833e-01, 4.9183e-01, ..., 3.6043e-01,\n", - " 3.8524e-01, 3.6155e-01],\n", - " [2.0068e-01, 4.2784e-01, 5.2818e-01, ..., 3.1871e-01,\n", - " 3.2452e-01, 3.1036e-01],\n", - " ...,\n", - " [4.9855e-01, 5.1001e-01, 5.2279e-01, ..., 3.6450e-01,\n", - " 3.4338e-01, 3.3603e-01],\n", - " [4.1233e-01, 5.5518e-01, 5.2828e-01, ..., 4.0676e-01,\n", - " 3.3873e-01, 3.6724e-01],\n", - " [4.0820e-01, 4.6187e-01, 4.7338e-01, ..., 3.8691e-01,\n", - " 3.6039e-01, 3.8022e-01]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 5.7852e-03, 0.0000e+00, ..., 7.4838e-03,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0351e-02,\n", - " 0.0000e+00, 2.6720e-04],\n", - " [9.4807e-04, 0.0000e+00, 0.0000e+00, ..., 7.9551e-03,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [2.0326e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 1.0801e-02, 0.0000e+00],\n", - " [1.8470e-01, 0.0000e+00, 0.0000e+00, ..., 5.0584e-02,\n", - " 9.4758e-02, 5.9146e-02]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.8708e-01, 2.8022e-01, 3.5893e-01, ..., 1.6595e-01,\n", - " 1.6031e-01, 2.1136e-01],\n", - " [1.5595e-01, 3.0544e-01, 2.4666e-01, ..., 2.2675e-01,\n", - " 2.5765e-01, 1.9682e-01],\n", - " [2.9518e-01, 4.1210e-01, 2.0063e-01, ..., 1.7595e-01,\n", - " 2.2537e-01, 2.2214e-01],\n", - " ...,\n", - " [2.4745e-01, 2.6259e-01, 3.8654e-01, ..., 2.3620e-01,\n", - " 2.3157e-01, 1.8514e-01],\n", - " [2.5715e-01, 2.9593e-01, 4.7745e-01, ..., 2.3546e-01,\n", - " 2.5073e-01, 2.0976e-01],\n", - " [1.2015e+00, 8.4644e-01, 7.3386e-01, ..., 1.0252e+00,\n", - " 9.5310e-01, 1.0013e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[4.5013e-01, 4.7484e-01, 4.0540e-01, ..., 1.9346e-01,\n", - " 1.7826e-01, 1.4777e-01],\n", - " [4.7546e-01, 4.8187e-01, 3.6760e-01, ..., 2.7809e-01,\n", - " 3.2997e-01, 3.2337e-01],\n", - " [4.6160e-01, 4.0050e-01, 3.9061e-01, ..., 3.6613e-01,\n", - " 3.5243e-01, 2.9739e-01],\n", - " ...,\n", - " [5.5148e-01, 5.1018e-01, 4.0132e-01, ..., 3.8948e-01,\n", - " 3.5737e-01, 3.3088e-01],\n", - " [4.1973e-01, 4.5475e-01, 4.5320e-01, ..., 3.8343e-01,\n", - " 4.0126e-01, 3.6181e-01],\n", - " [3.4280e-01, 3.1606e-01, 4.4701e-01, ..., 2.1665e-01,\n", - " 2.3985e-01, 2.3903e-01]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.1783e-02, 0.0000e+00, 1.5805e-02, ..., 0.0000e+00,\n", - " 2.2508e-02, 0.0000e+00],\n", - " [4.3234e-02, 7.7864e-02, 0.0000e+00, ..., 1.6347e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.2092e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3563e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[0.0000e+00, 2.5187e-01, 2.4979e-01, ..., 2.4775e-01,\n", - " 2.2354e-01, 1.9149e-01],\n", - " [1.6541e-01, 1.9586e-01, 1.9813e-01, ..., 2.7344e-01,\n", - " 2.0928e-01, 2.6150e-01],\n", - " [1.0495e-01, 6.3299e-02, 3.3844e-01, ..., 2.5138e-01,\n", - " 1.2470e-01, 2.3927e-01],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.1428e-01, 4.5667e-01, 4.6821e-01, ..., 3.2058e-01,\n", - " 3.3579e-01, 3.9013e-01],\n", - " [1.0441e-01, 4.5739e-01, 4.6107e-01, ..., 3.8468e-01,\n", - " 3.8291e-01, 3.6686e-01],\n", - " [1.9868e-01, 3.5520e-01, 4.4313e-01, ..., 4.0679e-01,\n", - " 3.8068e-01, 3.0646e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.4654e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 3.3902e-02],\n", - " [0.0000e+00, 0.0000e+00, 1.8307e-02, ..., 5.1669e-02,\n", - " 9.4838e-03, 7.4535e-02],\n", - " [9.9215e-02, 0.0000e+00, 1.5872e-02, ..., 1.6203e-02,\n", - " 5.1401e-02, 1.9239e-03],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[4.0034e-01, 2.5306e-01, 2.0218e-01, ..., 9.8162e-02,\n", - " 7.0643e-02, 4.9741e-02],\n", - " [1.2568e-01, 2.1031e-01, 1.1182e-01, ..., 4.2781e-02,\n", - " 1.1969e-01, 1.2005e-01],\n", - " [2.8787e-01, 2.4031e-01, 2.2566e-01, ..., 0.0000e+00,\n", - " 6.4181e-02, 5.8730e-02],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.8405e-01, 3.0990e-01, 3.7156e-01, ..., 1.8125e-01,\n", - " 1.5051e-01, 1.9620e-01],\n", - " [4.7286e-01, 4.0529e-01, 3.9718e-01, ..., 2.4710e-01,\n", - " 4.5657e-02, 1.1501e-01],\n", - " [3.2621e-01, 3.0073e-01, 3.0477e-01, ..., 2.3529e-01,\n", - " 2.1357e-01, 1.6986e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.3438e-02, 1.2378e-03, 5.2972e-02, ..., 7.2712e-02,\n", - " 8.6563e-02, 1.4494e-01],\n", - " [1.1043e-01, 6.1431e-02, 6.3630e-02, ..., 8.1278e-02,\n", - " 6.2590e-02, 8.3154e-02],\n", - " [1.7677e-02, 2.0111e-03, 7.8750e-02, ..., 6.9633e-02,\n", - " 8.9799e-02, 5.3263e-02],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.0034e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5627e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.1447e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 4.3641e-03],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[2.5142e-01, 4.5964e-01, 3.7346e-01, ..., 4.7631e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.9760e-01, 2.6627e-01, 1.1191e-01, ..., 3.0450e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6341e-01, 3.2938e-01, 2.5690e-01, ..., 5.5694e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 2.2189e-02, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 2.8490e-02],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.5810e-01, 6.3017e-01, 3.7038e-01, ..., 1.8704e-01,\n", - " 8.2694e-02, 9.9127e-02],\n", - " [1.7293e-01, 5.0679e-01, 4.0739e-01, ..., 1.6006e-01,\n", - " 1.1725e-01, 9.9405e-02],\n", - " [2.4175e-01, 4.1616e-01, 4.1257e-01, ..., 1.3520e-01,\n", - " 7.9126e-02, 1.2846e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n", - " grad_fn=)\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "\n", - "x = xs.unsqueeze(1)\n", - "x = model.encoder.embed.conv(x)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "9a9478ad", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.03426375 0.14291267 -0.06718873 ... 0.09064753 0.01809387\n", - " -0.0434088 ]\n", - " [-0.05007839 0.11054724 -0.10399298 ... 0.11457238 0.04244684\n", - " -0.01249714]\n", - " [-0.10695291 0.16910909 -0.08352133 ... 0.07710276 0.01168563\n", - " -0.03584499]\n", - " ...\n", - " [-0.06060536 0.14455931 -0.05470302 ... 0.05364908 0.03033342\n", - " -0.02610814]\n", - " [-0.08505894 0.13611752 -0.11132983 ... 0.13079923 0.01580139\n", - " -0.02281028]\n", - " [-0.10604677 0.14714901 -0.10885533 ... 0.08543444 0.03719445\n", - " -0.04634233]]\n", - "\n", - " [[-0.12392755 0.14486063 -0.05674079 ... 0.02573164 0.03128851\n", - " 0.00545091]\n", - " [-0.04775286 0.08473608 -0.08507854 ... 0.04573154 0.04240163\n", - " 0.01053247]\n", - " [-0.05940291 0.10023535 -0.0814373 ... 0.035965 0.01673085\n", - " 0.02089563]\n", - " ...\n", - " [-0.09222981 0.15823206 -0.07700447 ... 0.08122957 0.03136991\n", - " -0.00646474]\n", - " [-0.07331756 0.14482647 -0.07838815 ... 0.1086944 0.01356864\n", - " -0.02777974]\n", - " [-0.07937264 0.20143102 -0.05544947 ... 0.10287814 0.00608235\n", - " -0.0479918 ]]\n", - "\n", - " [[-0.03670349 0.0893159 -0.08718812 ... 0.0131405 0.00642052\n", - " 0.00573716]\n", - " [ 0.01089254 0.11146393 -0.10263617 ... 0.05070438 0.01960694\n", - " 0.03521532]\n", - " [-0.0218228 0.11443964 -0.06678198 ... 0.04327708 0.00861394\n", - " 0.02871092]\n", - " ...\n", - " [-0.06792898 0.14376275 -0.07899005 ... 0.11248926 0.03208683\n", - " -0.0326424 ]\n", - " [-0.07884051 0.17024788 -0.08583611 ... 0.09028331 0.03588808\n", - " -0.0207509 ]\n", - " [-0.13792302 0.27163863 -0.23930418 ... 0.13391261 0.0752104\n", - " -0.08621951]]\n", - "\n", - " ...\n", - "\n", - " [[-0.02446348 0.11595841 -0.03591986 ... 0.0628897 0.02895011\n", - " -0.06532725]\n", - " [-0.05378424 0.1260737 -0.09023033 ... 0.09078894 0.01035743\n", - " 0.03701983]\n", - " [-0.04566649 0.14275314 -0.0668687 ... 0.09890588 -0.00612222\n", - " 0.03439377]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]\n", - "\n", - " [[-0.01012144 0.03909408 -0.07077143 ... 0.00452683 -0.01377654\n", - " 0.02897627]\n", - " [-0.00519154 0.03594019 -0.06831125 ... 0.05693541 -0.00406374\n", - " 0.0456164 ]\n", - " [-0.01762631 0.00500899 -0.05886075 ... 0.02112178 -0.00729015\n", - " 0.02782153]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]\n", - "\n", - " [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402 0.04296734\n", - " 0.06151697]\n", - " [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064 0.05696633\n", - " 0.03304394]\n", - " [-0.01818341 -0.0117864 -0.09679577 ... -0.00870231 0.00362198\n", - " 0.01916483]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]]\n", - "torch.Size([16, 51, 256])\n" - ] - } - ], - "source": [ - "b, c, t, f = x.size()\n", - "x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))\n", - "print(x.cpu().detach().numpy())\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "fd69003f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", - " -0.6945408 ]\n", - " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", - " -0.1999542 ]\n", - " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", - " -0.5735198 ]\n", - " ...\n", - " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", - " -0.41773027]\n", - " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", - " -0.36496443]\n", - " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", - " -0.74147725]]\n", - "\n", - " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", - " 0.08721463]\n", - " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", - " 0.16851945]\n", - " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", - " 0.33433008]\n", - " ...\n", - " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", - " -0.10343577]\n", - " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", - " -0.44447583]\n", - " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", - " -0.7678688 ]]\n", - "\n", - " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", - " 0.09179455]\n", - " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", - " 0.56344515]\n", - " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", - " 0.45937473]\n", - " ...\n", - " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", - " -0.52227837]\n", - " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", - " -0.33201432]\n", - " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", - " -1.3795122 ]]\n", - "\n", - " ...\n", - "\n", - " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", - " -1.045236 ]\n", - " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", - " 0.5923172 ]\n", - " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", - " 0.55030036]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", - " 0.46362036]\n", - " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", - " 0.72986233]\n", - " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", - " 0.44514441]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", - " 0.9842716 ]\n", - " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", - " 0.52870303]\n", - " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", - " 0.30663735]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]]\n" - ] - } - ], - "source": [ - "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", - "print(x.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "8ed88489", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.float32\n", - "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", - " 0.0000000e+00 1.0000000e+00]\n", - " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", - " 1.0746076e-04 1.0000000e+00]\n", - " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", - " 2.1492151e-04 1.0000000e+00]\n", - " ...\n", - " [-7.6825464e-01 -6.4014435e-01 6.3279724e-01 ... 9.9998462e-01\n", - " 5.1580933e-03 9.9998671e-01]\n", - " [-9.5375264e-01 3.0059254e-01 9.9899054e-01 ... 9.9998397e-01\n", - " 5.2655530e-03 9.9998611e-01]\n", - " [-2.6237485e-01 9.6496606e-01 5.6074661e-01 ... 9.9998331e-01\n", - " 5.3730118e-03 9.9998558e-01]]]\n" - ] - } - ], - "source": [ - "print(pos_emb.dtype)\n", - "print(pos_emb.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "5e277881", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 51, 256])\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'mask' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpos_emb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 144\u001b[0m \u001b[0mx_att\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m )\n", - "\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined" - ] - } - ], - "source": [ - "def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n", - " use_dynamic_chunk: bool,\n", - " use_dynamic_left_chunk: bool,\n", - " decoding_chunk_size: int, static_chunk_size: int,\n", - " num_decoding_left_chunks: int):\n", - " \"\"\" Apply optional mask for encoder.\n", - " Args:\n", - " xs (torch.Tensor): padded input, (B, L, D), L for max length\n", - " mask (torch.Tensor): mask for xs, (B, 1, L)\n", - " use_dynamic_chunk (bool): whether to use dynamic chunk or not\n", - " use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n", - " training.\n", - " decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n", - " 0: default for training, use random dynamic chunk.\n", - " <0: for decoding, use full chunk.\n", - " >0: for decoding, use fixed chunk size as set.\n", - " static_chunk_size (int): chunk size for static chunk training/decoding\n", - " if it's greater than 0, if use_dynamic_chunk is true,\n", - " this parameter will be ignored\n", - " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", - " the chunk size is decoding_chunk_size.\n", - " >=0: use num_decoding_left_chunks\n", - " <0: use all left chunks\n", - " Returns:\n", - " torch.Tensor: chunk mask of the input xs.\n", - " \"\"\"\n", - " # Whether to use chunk mask or not\n", - " if use_dynamic_chunk:\n", - " max_len = xs.size(1)\n", - " if decoding_chunk_size < 0:\n", - " chunk_size = max_len\n", - " num_left_chunks = -1\n", - " elif decoding_chunk_size > 0:\n", - " chunk_size = decoding_chunk_size\n", - " num_left_chunks = num_decoding_left_chunks\n", - " else:\n", - " # chunk size is either [1, 25] or full context(max_len).\n", - " # Since we use 4 times subsampling and allow up to 1s(100 frames)\n", - " # delay, the maximum frame is 100 / 4 = 25.\n", - " chunk_size = torch.randint(1, max_len, (1, )).item()\n", - " num_left_chunks = -1\n", - " if chunk_size > max_len // 2:\n", - " chunk_size = max_len\n", - " else:\n", - " chunk_size = chunk_size % 25 + 1\n", - " if use_dynamic_left_chunk:\n", - " max_left_chunks = (max_len - 1) // chunk_size\n", - " num_left_chunks = torch.randint(0, max_left_chunks,\n", - " (1, )).item()\n", - " chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n", - " num_left_chunks,\n", - " xs.device) # (L, L)\n", - " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", - " chunk_masks = masks & chunk_masks # (B, L, L)\n", - " elif static_chunk_size > 0:\n", - " num_left_chunks = num_decoding_left_chunks\n", - " chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n", - " num_left_chunks,\n", - " xs.device) # (L, L)\n", - " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", - " chunk_masks = masks & chunk_masks # (B, L, L)\n", - " else:\n", - " chunk_masks = masks\n", - " return chunk_masks\n", - "\n", - "from wenet.utils.mask import make_pad_mask\n", - "\n", - "\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1)\n", - "xs = model.encoder.global_cmvn(feat)\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", - "\n", - "mask_pad = masks\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "use_dynamic_left_chunk=-1\n", - "use_dynamic_chunk=False\n", - "static_chunk_size=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, \n", - " masks, \n", - " use_dynamic_chunk,\n", - " use_dynamic_left_chunk,\n", - " decoding_chunk_size, \n", - " static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', \n", - " embed_out=xs.cpu().detach().numpy(), \n", - " pos_emb=pos_emb.cpu().detach().numpy(),\n", - " chunk_masks=chunk_masks.cpu().detach().numpy(),\n", - " mask_pad=mask_pad.cpu().detach().numpy())\n", - "\n", - "model.eval()\n", - "# print(chunk_masks)\n", - "print(xs.shape)\n", - "for layer in model.encoder.encoders:\n", - " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())\n", - " \n", - " x = xs\n", - " residual = x\n", - " x_norm = layer.norm_ff_macaron(x)\n", - " !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', \n", - " norm_ff=x_norm.cpu().detach().numpy(),\n", - " xs=xs.cpu().detach().numpy())\n", - " #print(x.cpu().detach().numpy())\n", - " for p in layer.norm_ff_macaron.parameters():\n", - " #print(p, p.sum())\n", - " pass\n", - " \n", - " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n", - " \n", - " ps = []\n", - " for n, p in layer.feed_forward_macaron.state_dict().items():\n", - " #print(n, p.cpu().data.numpy())\n", - " ps.append(p.cpu().data.numpy())\n", - " pass\n", - "\n", - " ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n", - " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", - " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', \n", - " norm_ff=x_norm.cpu().detach().numpy(),\n", - " ff_out=x.cpu().detach().numpy(),\n", - " ff_l_x = ff_l_x.cpu().detach().numpy(),\n", - " ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n", - " ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n", - " ps=ps,\n", - " )\n", - " \n", - " \n", - " residual = x\n", - " x = layer.norm_mha(x)\n", - " x_q = x\n", - " \n", - " x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', \n", - " x_q=x_q.cpu().detach().numpy(),\n", - " x=x.cpu().detach().numpy(),\n", - " pos_emb = pos_emb.cpu().detach().numpy(),\n", - " mask=mask.cpu().detach().numpy(),\n", - " x_att=x_att.cpu().detach().numpy(),\n", - " )\n", - " \n", - " break\n", - "#print(xs.cpu().detach().numpy())\n", - "\n", - "\n", - "i = 0\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " i += 1\n", - " if i == 2:\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())\n", - " \n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c43fd4f1", - "metadata": {}, - "outputs": [], - "source": [ - "out, mask = model.encoder(feat, feat_len)\n", - "#print(out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e73db22", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f506114", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/README.md b/README.md index de24abe2f..71bc63638 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -[中文版](README_cn.md) - # PaddlePaddle Speech to Any toolkit ![License](https://img.shields.io/badge/license-Apache%202-red.svg) @@ -11,31 +9,29 @@ ## Features - See [feature list](doc/src/feature_list.md) for more information. + See [feature list](docs/src/feature_list.md) for more information. ## Setup All tested under: * Ubuntu 16.04 * python>=3.7 -* paddlepaddle>=2.1.2 +* paddlepaddle>=2.2.0rc -Please see [install](doc/src/install.md). +Please see [install](docs/src/install.md). ## Getting Started -Please see [Getting Started](doc/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md). +Please see [Getting Started](docs/src/getting_started.md) and [tiny egs](examples/tiny/s0/README.md). ## More Information -* [Data Prepration](doc/src/data_preparation.md) -* [Data Augmentation](doc/src/augmentation.md) -* [Ngram LM](doc/src/ngram_lm.md) -* [Server Demo](doc/src/server.md) -* [Benchmark](doc/src/benchmark.md) -* [Relased Model](doc/src/released_model.md) -* [FAQ](doc/src/faq.md) +* [Data Prepration](docs/src/data_preparation.md) +* [Data Augmentation](docs/src/augmentation.md) +* [Ngram LM](docs/src/ngram_lm.md) +* [Benchmark](docs/src/benchmark.md) +* [Relased Model](docs/src/released_model.md) ## Questions and Help @@ -45,8 +41,8 @@ You are welcome to submit questions in [Github Discussions](https://github.com/P ## License -DeepASR is provided under the [Apache-2.0 License](./LICENSE). +DeepSpeech is provided under the [Apache-2.0 License](./LICENSE). ## Acknowledgement -We depends on many open source repos. See [References](doc/src/reference.md) for more information. +We depends on many open source repos. See [References](docs/src/reference.md) for more information. diff --git a/README_cn.md b/README_cn.md deleted file mode 100644 index 4b9273625..000000000 --- a/README_cn.md +++ /dev/null @@ -1,51 +0,0 @@ -[English](README.md) - -# PaddlePaddle Speech to Any toolkit - -![License](https://img.shields.io/badge/license-Apache%202-red.svg) -![python version](https://img.shields.io/badge/python-3.7+-orange.svg) -![support os](https://img.shields.io/badge/os-linux-yellow.svg) - -*DeepSpeech*是一个采用[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)平台的端到端自动语音识别引擎的开源项目, -我们的愿景是为语音识别在工业应用和学术研究上,提供易于使用、高效、小型化和可扩展的工具,包括训练,推理,以及 部署。 - -## 特性 - - 参看 [特性列表](doc/src/feature_list.md)。 - - -## 安装 - -在以下环境测试验证过: - -* Ubuntu 16.04 -* python>=3.7 -* paddlepaddle>=2.1.2 - -参看 [安装](doc/src/install.md)。 - -## 开始 - -请查看 [开始](doc/src/getting_started.md) 和 [tiny egs](examples/tiny/s0/README.md)。 - -## 更多信息 - -* [数据处理](doc/src/data_preparation.md) -* [数据增强](doc/src/augmentation.md) -* [语言模型](doc/src/ngram_lm.md) -* [服务部署](doc/src/server.md) -* [Benchmark](doc/src/benchmark.md) -* [Relased Model](doc/src/released_model.md) -* [FAQ](doc/src/faq.md) - -## 问题和帮助 - -欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。 - -## License - -DeepASR 遵循[Apache-2.0开源协议](./LICENSE)。 - -## 感谢 - -开发中参考一些优秀的仓库,详情参见 [References](doc/src/reference.md)。 diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index d85a3dde7..5505ecbf0 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -80,23 +80,23 @@ def convert_dtype_to_string(tensor_dtype): if not hasattr(paddle, 'softmax'): - logger.warn("register user softmax to paddle, remove this when fixed!") + logger.debug("register user softmax to paddle, remove this when fixed!") setattr(paddle, 'softmax', paddle.nn.functional.softmax) if not hasattr(paddle, 'log_softmax'): - logger.warn("register user log_softmax to paddle, remove this when fixed!") + logger.debug("register user log_softmax to paddle, remove this when fixed!") setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) if not hasattr(paddle, 'sigmoid'): - logger.warn("register user sigmoid to paddle, remove this when fixed!") + logger.debug("register user sigmoid to paddle, remove this when fixed!") setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) if not hasattr(paddle, 'log_sigmoid'): - logger.warn("register user log_sigmoid to paddle, remove this when fixed!") + logger.debug("register user log_sigmoid to paddle, remove this when fixed!") setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) if not hasattr(paddle, 'relu'): - logger.warn("register user relu to paddle, remove this when fixed!") + logger.debug("register user relu to paddle, remove this when fixed!") setattr(paddle, 'relu', paddle.nn.functional.relu) @@ -105,7 +105,7 @@ def cat(xs, dim=0): if not hasattr(paddle, 'cat'): - logger.warn( + logger.debug( "override cat of paddle if exists or register, remove this when fixed!") paddle.cat = cat @@ -116,7 +116,7 @@ def item(x: paddle.Tensor): if not hasattr(paddle.Tensor, 'item'): - logger.warn( + logger.debug( "override item of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.item = item @@ -127,13 +127,13 @@ def func_long(x: paddle.Tensor): if not hasattr(paddle.Tensor, 'long'): - logger.warn( + logger.debug( "override long of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.long = func_long if not hasattr(paddle.Tensor, 'numel'): - logger.warn( + logger.debug( "override numel of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.numel = paddle.numel @@ -147,7 +147,7 @@ def new_full(x: paddle.Tensor, if not hasattr(paddle.Tensor, 'new_full'): - logger.warn( + logger.debug( "override new_full of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.new_full = new_full @@ -162,13 +162,13 @@ def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'eq'): - logger.warn( + logger.debug( "override eq of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.eq = eq if not hasattr(paddle, 'eq'): - logger.warn( + logger.debug( "override eq of paddle if exists or register, remove this when fixed!") paddle.eq = eq @@ -178,7 +178,7 @@ def contiguous(xs: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'contiguous'): - logger.warn( + logger.debug( "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" ) paddle.Tensor.contiguous = contiguous @@ -195,7 +195,7 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: #`to_static` do not process `size` property, maybe some `paddle` api dependent on it. -logger.warn( +logger.debug( "override size of paddle.Tensor " "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" ) @@ -207,7 +207,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'view'): - logger.warn("register user view to paddle.Tensor, remove this when fixed!") + logger.debug("register user view to paddle.Tensor, remove this when fixed!") paddle.Tensor.view = view @@ -216,7 +216,7 @@ def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'view_as'): - logger.warn( + logger.debug( "register user view_as to paddle.Tensor, remove this when fixed!") paddle.Tensor.view_as = view_as @@ -242,7 +242,7 @@ def masked_fill(xs: paddle.Tensor, if not hasattr(paddle.Tensor, 'masked_fill'): - logger.warn( + logger.debug( "register user masked_fill to paddle.Tensor, remove this when fixed!") paddle.Tensor.masked_fill = masked_fill @@ -260,7 +260,7 @@ def masked_fill_(xs: paddle.Tensor, if not hasattr(paddle.Tensor, 'masked_fill_'): - logger.warn( + logger.debug( "register user masked_fill_ to paddle.Tensor, remove this when fixed!") paddle.Tensor.masked_fill_ = masked_fill_ @@ -272,7 +272,8 @@ def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'fill_'): - logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user fill_ to paddle.Tensor, remove this when fixed!") paddle.Tensor.fill_ = fill_ @@ -281,22 +282,22 @@ def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'repeat'): - logger.warn( + logger.debug( "register user repeat to paddle.Tensor, remove this when fixed!") paddle.Tensor.repeat = repeat if not hasattr(paddle.Tensor, 'softmax'): - logger.warn( + logger.debug( "register user softmax to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) if not hasattr(paddle.Tensor, 'sigmoid'): - logger.warn( + logger.debug( "register user sigmoid to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) if not hasattr(paddle.Tensor, 'relu'): - logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + logger.debug("register user relu to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) @@ -305,7 +306,7 @@ def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'type_as'): - logger.warn( + logger.debug( "register user type_as to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'type_as', type_as) @@ -321,7 +322,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'to'): - logger.warn("register user to to paddle.Tensor, remove this when fixed!") + logger.debug("register user to to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'to', to) @@ -330,7 +331,8 @@ def func_float(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'float'): - logger.warn("register user float to paddle.Tensor, remove this when fixed!") + logger.debug( + "register user float to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'float', func_float) @@ -339,7 +341,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor: if not hasattr(paddle.Tensor, 'int'): - logger.warn("register user int to paddle.Tensor, remove this when fixed!") + logger.debug("register user int to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'int', func_int) @@ -348,23 +350,6 @@ def tolist(x: paddle.Tensor) -> List[Any]: if not hasattr(paddle.Tensor, 'tolist'): - logger.warn( + logger.debug( "register user tolist to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'tolist', tolist) - - -########### hcak paddle.nn ############# -class GLU(nn.Layer): - """Gated Linear Units (GLU) Layer""" - - def __init__(self, dim: int=-1): - super().__init__() - self.dim = dim - - def forward(self, xs): - return F.glu(xs, axis=self.dim) - - -if not hasattr(paddle.nn, 'GLU'): - logger.warn("register user GLU to paddle.nn, remove this when fixed!") - setattr(paddle.nn, 'GLU', GLU) diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp index 4dcc7c899..fcb1f7642 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.cpp @@ -35,7 +35,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { @@ -48,7 +49,7 @@ std::vector> ctc_beam_search_decoder( // assign blank id // size_t blank_id = vocabulary.size(); - size_t blank_id = 0; + // size_t blank_id = 0; // assign space id auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); @@ -57,7 +58,6 @@ std::vector> ctc_beam_search_decoder( if ((size_t)space_id >= vocabulary.size()) { space_id = -2; } - // init prefixes' root PathTrie root; root.score = root.log_prob_b_prev = 0.0; @@ -218,7 +218,8 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob, size_t cutoff_top_n, - Scorer *ext_scorer) { + Scorer *ext_scorer, + size_t blank_id) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); // thread pool ThreadPool pool(num_processes); @@ -234,7 +235,8 @@ ctc_beam_search_decoder_batch( beam_size, cutoff_prob, cutoff_top_n, - ext_scorer)); + ext_scorer, + blank_id)); } // get decoding results diff --git a/deepspeech/decoders/swig/ctc_beam_search_decoder.h b/deepspeech/decoders/swig/ctc_beam_search_decoder.h index c31510da3..eaba9da8c 100644 --- a/deepspeech/decoders/swig/ctc_beam_search_decoder.h +++ b/deepspeech/decoders/swig/ctc_beam_search_decoder.h @@ -43,7 +43,8 @@ std::vector> ctc_beam_search_decoder( size_t beam_size, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); /* CTC Beam Search Decoder for batch data @@ -70,6 +71,7 @@ ctc_beam_search_decoder_batch( size_t num_processes, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr); + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp index 1c735c424..18008cced 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.cpp +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.cpp @@ -17,17 +17,18 @@ std::string ctc_greedy_decoder( const std::vector> &probs_seq, - const std::vector &vocabulary) { + const std::vector &vocabulary, + size_t blank_id) { // dimension check size_t num_time_steps = probs_seq.size(); for (size_t i = 0; i < num_time_steps; ++i) { VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size() + 1, + vocabulary.size(), "The shape of probs_seq does not match with " "the shape of the vocabulary"); } - size_t blank_id = vocabulary.size(); + // size_t blank_id = vocabulary.size(); std::vector max_idx_vec(num_time_steps, 0); std::vector idx_vec; diff --git a/deepspeech/decoders/swig/ctc_greedy_decoder.h b/deepspeech/decoders/swig/ctc_greedy_decoder.h index 5e8c5c251..dd1b33315 100644 --- a/deepspeech/decoders/swig/ctc_greedy_decoder.h +++ b/deepspeech/decoders/swig/ctc_greedy_decoder.h @@ -29,6 +29,7 @@ */ std::string ctc_greedy_decoder( const std::vector>& probs_seq, - const std::vector& vocabulary); + const std::vector& vocabulary, + size_t blank_id); #endif // CTC_GREEDY_DECODER_H diff --git a/deepspeech/decoders/swig/setup.py b/deepspeech/decoders/swig/setup.py index 8fb792962..c089f96cd 100644 --- a/deepspeech/decoders/swig/setup.py +++ b/deepspeech/decoders/swig/setup.py @@ -85,9 +85,8 @@ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') # yapf: disable FILES = [ - fn for fn in FILES - if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( - 'unittest.cc')) + fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') + or fn.endswith('unittest.cc')) ] # yapf: enable diff --git a/deepspeech/decoders/swig_wrapper.py b/deepspeech/decoders/swig_wrapper.py index 3ffdb9c74..d883d430c 100644 --- a/deepspeech/decoders/swig_wrapper.py +++ b/deepspeech/decoders/swig_wrapper.py @@ -32,7 +32,7 @@ class Scorer(swig_decoders.Scorer): swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) -def ctc_greedy_decoder(probs_seq, vocabulary): +def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): """Wrapper for ctc best path decoder in swig. :param probs_seq: 2-D list of probability distributions over each time @@ -44,7 +44,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): :return: Decoding result string. :rtype: str """ - result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary) + result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, + blank_id) return result @@ -53,7 +54,8 @@ def ctc_beam_search_decoder(probs_seq, beam_size, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the CTC Beam Search Decoder. :param probs_seq: 2-D list of probability distributions over each time @@ -81,7 +83,7 @@ def ctc_beam_search_decoder(probs_seq, """ beam_results = swig_decoders.ctc_beam_search_decoder( probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, - ext_scoring_func) + ext_scoring_func, blank_id) beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] return beam_results @@ -92,7 +94,8 @@ def ctc_beam_search_decoder_batch(probs_split, num_processes, cutoff_prob=1.0, cutoff_top_n=40, - ext_scoring_func=None): + ext_scoring_func=None, + blank_id=0): """Wrapper for the batched CTC beam search decoder. :param probs_seq: 3-D list with each element as an instance of 2-D list @@ -125,7 +128,7 @@ def ctc_beam_search_decoder_batch(probs_split, batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( probs_split, vocabulary, beam_size, num_processes, cutoff_prob, - cutoff_top_n, ext_scoring_func) + cutoff_top_n, ext_scoring_func, blank_id) batch_beam_results = [[(res[0], res[1]) for res in beam_results] for beam_results in batch_beam_results] return batch_beam_results diff --git a/deepspeech/exps/deepspeech2/bin/train.py b/deepspeech/exps/deepspeech2/bin/train.py index 69ff043a0..6740f288f 100644 --- a/deepspeech/exps/deepspeech2/bin/train.py +++ b/deepspeech/exps/deepspeech2/bin/train.py @@ -27,7 +27,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py deleted file mode 100644 index 94a9b6c47..000000000 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Beam search parameters tuning for DeepSpeech2 model.""" -import functools -import sys - -import numpy as np -from paddle.io import DataLoader - -from deepspeech.exps.deepspeech2.config import get_cfg_defaults -from deepspeech.io.collator import SpeechCollator -from deepspeech.io.dataset import ManifestDataset -from deepspeech.models.ds2 import DeepSpeech2Model -from deepspeech.training.cli import default_argument_parser -from deepspeech.utils import error_rate -from deepspeech.utils.utility import add_arguments -from deepspeech.utils.utility import print_arguments - - -def tune(config, args): - """Tune parameters alpha and beta incrementally.""" - if not args.num_alphas >= 0: - raise ValueError("num_alphas must be non-negative!") - if not args.num_betas >= 0: - raise ValueError("num_betas must be non-negative!") - config.defrost() - config.data.manfiest = config.data.dev_manifest - config.data.augmentation_config = "" - config.data.keep_transcription_text = True - dev_dataset = ManifestDataset.from_config(config) - - valid_loader = DataLoader( - dev_dataset, - batch_size=config.data.batch_size, - shuffle=False, - drop_last=False, - collate_fn=SpeechCollator(keep_transcription_text=True)) - - model = DeepSpeech2Model.from_pretrained(valid_loader, config, - args.checkpoint_path) - model.eval() - - # decoders only accept string encoded in utf-8 - vocab_list = valid_loader.dataset.vocab_list - errors_func = error_rate.char_errors if config.decoding.error_rate_type == 'cer' else error_rate.word_errors - - # create grid for search - cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) - cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) - params_grid = [(alpha, beta) for alpha in cand_alphas - for beta in cand_betas] - - err_sum = [0.0 for i in range(len(params_grid))] - err_ave = [0.0 for i in range(len(params_grid))] - - num_ins, len_refs, cur_batch = 0, 0, 0 - # initialize external scorer - model.decoder.init_decode(args.alpha_from, args.beta_from, - config.decoding.lang_model_path, vocab_list, - config.decoding.decoding_method) - ## incremental tuning parameters over multiple batches - print("start tuning ...") - for infer_data in valid_loader(): - if (args.num_batches >= 0) and (cur_batch >= args.num_batches): - break - - def ordid2token(texts, texts_len): - """ ord() id to chr() chr """ - trans = [] - for text, n in zip(texts, texts_len): - n = n.numpy().item() - ids = text[:n] - trans.append(''.join([chr(i) for i in ids])) - return trans - - audio, audio_len, text, text_len = infer_data - target_transcripts = ordid2token(text, text_len) - num_ins += audio.shape[0] - - # model infer - eouts, eouts_len = model.encoder(audio, audio_len) - probs = model.decoder.softmax(eouts) - - # grid search - for index, (alpha, beta) in enumerate(params_grid): - print(f"tuneing: alpha={alpha} beta={beta}") - result_transcripts = model.decoder.decode_probs( - probs.numpy(), eouts_len, vocab_list, - config.decoding.decoding_method, - config.decoding.lang_model_path, alpha, beta, - config.decoding.beam_size, config.decoding.cutoff_prob, - config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch) - - for target, result in zip(target_transcripts, result_transcripts): - errors, len_ref = errors_func(target, result) - err_sum[index] += errors - - # accumulate the length of references of every batchπ - # in the first iteration - if args.alpha_from == alpha and args.beta_from == beta: - len_refs += len_ref - - err_ave[index] = err_sum[index] / len_refs - if index % 2 == 0: - sys.stdout.write('.') - sys.stdout.flush() - print("tuneing: one grid done!") - - # output on-line tuning result at the end of current batch - err_ave_min = min(err_ave) - min_index = err_ave.index(err_ave_min) - print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), " - " min [%s] = %f" % - (cur_batch, num_ins, "%.3f" % params_grid[min_index][0], - "%.3f" % params_grid[min_index][1], - config.decoding.error_rate_type, err_ave_min)) - cur_batch += 1 - - # output WER/CER at every (alpha, beta) - print("\nFinal %s:\n" % config.decoding.error_rate_type) - for index in range(len(params_grid)): - print("(alpha, beta) = (%s, %s), [%s] = %f" % - ("%.3f" % params_grid[index][0], "%.3f" % params_grid[index][1], - config.decoding.error_rate_type, err_ave[index])) - - err_ave_min = min(err_ave) - min_index = err_ave.index(err_ave_min) - print("\nFinish tuning on %d batches, final opt (alpha, beta) = (%s, %s)" % - (cur_batch, "%.3f" % params_grid[min_index][0], - "%.3f" % params_grid[min_index][1])) - - print("finish tuning") - - -def main(config, args): - tune(config, args) - - -if __name__ == "__main__": - parser = default_argument_parser() - add_arg = functools.partial(add_arguments, argparser=parser) - add_arg('num_batches', int, -1, "# of batches tuning on. " - "Default -1, on whole dev set.") - add_arg('num_alphas', int, 45, "# of alpha candidates for tuning.") - add_arg('num_betas', int, 8, "# of beta candidates for tuning.") - add_arg('alpha_from', float, 1.0, "Where alpha starts tuning from.") - add_arg('alpha_to', float, 3.2, "Where alpha ends tuning with.") - add_arg('beta_from', float, 0.1, "Where beta starts tuning from.") - add_arg('beta_to', float, 0.45, "Where beta ends tuning with.") - - add_arg('batch_size', int, 256, "# of samples per batch.") - add_arg('beam_size', int, 500, "Beam search width.") - add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.") - add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") - add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") - - args = parser.parse_args() - print_arguments(args, globals()) - - # https://yaml.org/type/float.html - config = get_cfg_defaults() - if args.config: - config.merge_from_file(args.config) - if args.opts: - config.merge_from_list(args.opts) - - config.data.batch_size = args.batch_size - config.decoding.beam_size = args.beam_size - config.decoding.num_proc_bsearch = args.num_proc_bsearch - config.decoding.cutoff_prob = args.cutoff_prob - config.decoding.cutoff_top_n = args.cutoff_top_n - - config.freeze() - print(config) - - if args.dump_config: - with open(args.dump_config, 'w') as f: - print(config, file=f) - - main(config, args) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index f3e3fcadf..79a676345 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -15,9 +15,11 @@ import os import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -34,12 +36,14 @@ from deepspeech.models.ds2 import DeepSpeech2Model from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline from deepspeech.models.ds2_online import DeepSpeech2ModelOnline from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.reporter import report from deepspeech.training.trainer import Trainer from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -65,29 +69,52 @@ class DeepSpeech2Trainer(Trainer): super().__init__(config, args) def train_batch(self, batch_index, batch_data, msg): + batch_size = self.config.collator.batch_size + accum_grad = self.config.training.accum_grad + start = time.time() + + # forward utt, audio, audio_len, text, text_len = batch_data loss = self.model(audio, audio_len, text, text_len) - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - self.optimizer.step() - self.optimizer.clear_grad() - iteration_time = time.time() - start - losses_np = { 'train_loss': float(loss), } - msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.collator.batch_size) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_np.items()) - logger.info(msg) + + # loss backward + if (batch_index + 1) % accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step + if (batch_index + 1) % accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.iteration += 1 + + iteration_time = time.time() - start + + for k, v in losses_np.items(): + report(k, v) + report("batch_size", batch_size) + report("accum", accum_grad) + report("step_cost", iteration_time) if dist.get_rank() == 0 and self.visualizer: for k, v in losses_np.items(): + # `step -1` since we update `step` after optimizer.step(). self.visualizer.add_scalar("train/{}".format(k), v, - self.iteration) - self.iteration += 1 + self.iteration - 1) @paddle.no_grad() def valid(self): @@ -124,10 +151,9 @@ class DeepSpeech2Trainer(Trainer): def setup_model(self): config = self.config.clone() - config.defrost() - config.model.feat_size = self.train_loader.collate_fn.feature_size - config.model.dict_size = self.train_loader.collate_fn.vocab_size - config.freeze() + with UpdateConfig(config): + config.model.feat_size = self.train_loader.collate_fn.feature_size + config.model.dict_size = self.train_loader.collate_fn.vocab_size if self.args.model_type == 'offline': model = DeepSpeech2Model.from_config(config.model) @@ -280,9 +306,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("Current error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -325,7 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch metrics = self.compute_metrics(utts, audio, audio_len, texts, @@ -378,7 +405,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() @@ -610,7 +637,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() diff --git a/deepspeech/exps/u2/bin/train.py b/deepspeech/exps/u2/bin/train.py index 9dd0041dd..17fb08a6c 100644 --- a/deepspeech/exps/u2/bin/train.py +++ b/deepspeech/exps/u2/bin/train.py @@ -22,6 +22,8 @@ from deepspeech.exps.u2.model import U2Trainer as Trainer from deepspeech.training.cli import default_argument_parser from deepspeech.utils.utility import print_arguments +# from deepspeech.exps.u2.trainer import U2Trainer as Trainer + def main_sp(config, args): exp = Trainer(config, args) @@ -30,7 +32,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 0662e38d9..5cb0962a7 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -17,9 +17,12 @@ import os import sys import time from collections import defaultdict +from collections import OrderedDict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -32,7 +35,10 @@ from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory +from deepspeech.training.reporter import ObsScope +from deepspeech.training.reporter import report from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate @@ -41,6 +47,7 @@ from deepspeech.utils import mp_tools from deepspeech.utils import text_grid from deepspeech.utils import utility from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -79,21 +86,36 @@ class U2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() - utt, audio, audio_len, text, text_len = batch_data + # forward + utt, audio, audio_len, text, text_len = batch_data loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - losses_np = {'loss': float(loss) * train_conf.accum_grad} if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: losses_np['ctc_loss'] = float(ctc_loss) + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + # When using cpu w/o DDP, model does not have `no_sync` + context = self.model.no_sync if self.parallel else nullcontext + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step if (batch_index + 1) % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() @@ -102,14 +124,13 @@ class U2Trainer(Trainer): iteration_time = time.time() - start - if (batch_index + 1) % train_conf.log_interval == 0: - msg += "train time: {:>.3f}s, ".format(iteration_time) - msg += "batch size: {}, ".format(self.config.collator.batch_size) - msg += "accum: {}, ".format(train_conf.accum_grad) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_np.items()) - logger.info(msg) + for k, v in losses_np.items(): + report(k, v) + report("batch_size", self.config.collator.batch_size) + report("accum", train_conf.accum_grad) + report("step_cost", iteration_time) + if (batch_index + 1) % train_conf.accum_grad == 0: if dist.get_rank() == 0 and self.visualizer: losses_np_v = losses_np.copy() losses_np_v.update({"lr": self.lr_scheduler()}) @@ -163,46 +184,58 @@ class U2Trainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init') - - self.lr_scheduler.step(self.iteration) - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train:" + observation = OrderedDict() + with ObsScope(observation): + report("Rank", dist.get_rank()) + report("epoch", self.epoch) + report('step', self.iteration) + report("lr", self.lr_scheduler()) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + report('iter', batch_index + 1) + report('total', len(self.train_loader)) + report('reader_cost', dataload_time) + observation['batch_cost'] = observation[ + 'reader_cost'] + observation['step_cost'] + observation['samples'] = observation['batch_size'] + observation['ips[sent./sec]'] = observation[ + 'batch_size'] / observation['batch_cost'] + for k, v in observation.items(): + msg += f" {k}: " + msg += f"{v:>.8f}" if isinstance(v, + float) else f"{v}" + msg += "," + if (batch_index + 1 + ) % self.config.training.log_interval == 0: + logger.info(msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -294,10 +327,11 @@ class U2Trainer(Trainer): def setup_model(self): config = self.config model_conf = config.model - model_conf.defrost() - model_conf.input_dim = self.train_loader.collate_fn.feature_size - model_conf.output_dim = self.train_loader.collate_fn.vocab_size - model_conf.freeze() + + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + model = U2Model.from_config(model_conf) if self.parallel: @@ -433,9 +467,10 @@ class U2Tester(U2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -460,7 +495,7 @@ class U2Tester(U2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] @@ -540,7 +575,7 @@ class U2Tester(U2Trainer): # 1. Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) @@ -548,26 +583,25 @@ class U2Tester(U2Trainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - logger.info("align ids", key[0], alignment) + logger.info(f"align ids: {key[0]} {alignment}") fout.write('{} {}\n'.format(key[0], alignment)) # 3. gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - logger.info("align tokens", key[0], align_segs) + logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] subsample = utility.get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) # write tier - align_output_path = os.path.join( - os.path.dirname(self.args.result_file), "align") - tier_path = os.path.join(align_output_path, key[0] + ".tier") - with open(tier_path, 'w') as f: + align_output_path = Path(self.args.result_file).parent / "align" + align_output_path.mkdir(parents=True, exist_ok=True) + tier_path = align_output_path / (key[0] + ".tier") + with tier_path.open('w') as f: f.writelines(tierformat) # write textgrid - textgrid_path = os.path.join(align_output_path, - key[0] + ".TextGrid") + textgrid_path = align_output_path / (key[0] + ".TextGrid") second_per_frame = 1. / (1000. / stride_ms) # 25ms window, 10ms stride second_per_example = ( @@ -575,7 +609,7 @@ class U2Tester(U2Trainer): text_grid.generate_textgrid( maxtime=second_per_example, intervals=tierformat, - output=textgrid_path) + output=str(textgrid_path)) def run_align(self): self.resume_or_scratch() @@ -621,7 +655,7 @@ class U2Tester(U2Trainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() diff --git a/deepspeech/exps/u2/trainer.py b/deepspeech/exps/u2/trainer.py new file mode 100644 index 000000000..8e8634ac3 --- /dev/null +++ b/deepspeech/exps/u2/trainer.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains U2 model.""" +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.u2 import U2Evaluator +from deepspeech.models.u2 import U2Model +from deepspeech.models.u2 import U2Updater +from deepspeech.training.extensions.snapshot import Snapshot +from deepspeech.training.extensions.visualizer import VisualDL +from deepspeech.training.optimizer import OptimizerFactory +from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer +from deepspeech.training.trainer import Trainer +from deepspeech.training.updaters.trainer import Trainer as NewTrainer +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig + +logger = Log(__name__).getlog() + + +class U2Trainer(Trainer): + def __init__(self, config, args): + super().__init__(config, args) + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + # train/valid dataset, return token ids + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers, ) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + + # test dataset, return raw text + config.data.manifest = config.data.test_manifest + # filter test examples, will cause less examples, but no mismatch with training + # and can use large batch size , save training time, so filter test egs now. + config.data.min_input_len = 0.0 # second + config.data.max_input_len = float('inf') # second + config.data.min_output_len = 0.0 # tokens + config.data.max_output_len = float('inf') # tokens + config.data.min_output_input_ratio = 0.00 + config.data.max_output_input_ratio = float('inf') + + test_dataset = ManifestDataset.from_config(config) + # return text ord id + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + # return text token id + config.collator.keep_transcription_text = False + self.align_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + logger.info("Setup train/valid/test/align Dataloader!") + + def setup_model(self): + config = self.config + model_conf = config.model + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + + model = U2Model.from_config(model_conf) + + if self.parallel: + model = paddle.DataParallel(model) + + model.train() + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + train_config = config.training + optim_type = train_config.optim + optim_conf = train_config.optim_conf + scheduler_type = train_config.scheduler + scheduler_conf = train_config.scheduler_conf + + scheduler_args = { + "learning_rate": optim_conf.lr, + "verbose": False, + "warmup_steps": scheduler_conf.warmup_steps, + "gamma": scheduler_conf.lr_decay, + "d_model": model_conf.encoder_conf.output_size, + } + lr_scheduler = LRSchedulerFactory.from_args(scheduler_type, + scheduler_args) + + def optimizer_args( + config, + parameters, + lr_scheduler=None, ): + train_config = config.training + optim_type = train_config.optim + optim_conf = train_config.optim_conf + scheduler_type = train_config.scheduler + scheduler_conf = train_config.scheduler_conf + return { + "grad_clip": train_config.global_grad_clip, + "weight_decay": optim_conf.weight_decay, + "learning_rate": lr_scheduler + if lr_scheduler else optim_conf.lr, + "parameters": parameters, + "epsilon": 1e-9 if optim_type == 'noam' else None, + "beta1": 0.9 if optim_type == 'noam' else None, + "beat2": 0.98 if optim_type == 'noam' else None, + } + + optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler) + optimizer = OptimizerFactory.from_args(optim_type, optimzer_args) + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + logger.info("Setup model/optimizer/lr_scheduler!") + + def setup_updater(self): + output_dir = self.output_dir + config = self.config.training + + updater = U2Updater( + model=self.model, + optimizer=self.optimizer, + scheduler=self.lr_scheduler, + dataloader=self.train_loader, + output_dir=output_dir, + accum_grad=config.accum_grad) + + trainer = NewTrainer(updater, (config.n_epoch, 'epoch'), output_dir) + + evaluator = U2Evaluator(self.model, self.valid_loader) + + trainer.extend(evaluator, trigger=(1, "epoch")) + + if dist.get_rank() == 0: + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) + num_snapshots = config.checkpoint.kbest_n + trainer.extend( + Snapshot( + mode='kbest', + max_size=num_snapshots, + indicator='VALID/LOSS', + less_better=True), + trigger=(1, 'epoch')) + # print(trainer.extensions) + # trainer.run() + self.trainer = trainer + + def run(self): + """The routine of the experiment after setup. This method is intended + to be used by the user. + """ + self.setup_updater() + with Timer("Training Done: {}"): + self.trainer.run() diff --git a/deepspeech/exps/u2_kaldi/bin/train.py b/deepspeech/exps/u2_kaldi/bin/train.py index 1dcd154d3..d909727f3 100644 --- a/deepspeech/exps/u2_kaldi/bin/train.py +++ b/deepspeech/exps/u2_kaldi/bin/train.py @@ -36,7 +36,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 6a932d751..d38afe25c 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -17,9 +17,11 @@ import os import sys import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -31,6 +33,7 @@ from deepspeech.io.dataloader import BatchDataLoader from deepspeech.models.u2 import U2Model from deepspeech.training.optimizer import OptimizerFactory from deepspeech.training.scheduler import LRSchedulerFactory +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate @@ -39,6 +42,7 @@ from deepspeech.utils import mp_tools from deepspeech.utils import text_grid from deepspeech.utils import utility from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -83,20 +87,34 @@ class U2Trainer(Trainer): train_conf = self.config.training start = time.time() + # forward utt, audio, audio_len, text, text_len = batch_data loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - losses_np = {'loss': float(loss) * train_conf.accum_grad} if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: losses_np['ctc_loss'] = float(ctc_loss) + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step if (batch_index + 1) % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() @@ -167,43 +185,42 @@ class U2Trainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init') - self.lr_scheduler.step(self.iteration) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -300,10 +317,10 @@ class U2Trainer(Trainer): # model model_conf = config.model - model_conf.defrost() - model_conf.input_dim = self.train_loader.feat_dim - model_conf.output_dim = self.train_loader.vocab_size - model_conf.freeze() + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.feat_dim + model_conf.output_dim = self.train_loader.vocab_size + model = U2Model.from_config(model_conf) if self.parallel: model = paddle.DataParallel(model) @@ -429,9 +446,10 @@ class U2Tester(U2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example error rate [%s] = %f" % (cfg.error_rate_type, error_rate_func(target, result))) @@ -456,7 +474,7 @@ class U2Tester(U2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] @@ -526,9 +544,8 @@ class U2Tester(U2Trainer): self.model.eval() logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") - stride_ms = self.config.collater.stride_ms - token_dict = self.args.char_list - + stride_ms = self.align_loader.collate_fn.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list with open(self.args.result_file, 'w') as fout: # one example in batch for i, batch in enumerate(self.align_loader): @@ -537,7 +554,7 @@ class U2Tester(U2Trainer): # 1. Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) @@ -545,26 +562,25 @@ class U2Tester(U2Trainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - logger.info("align ids", key[0], alignment) + logger.info(f"align ids: {key[0]} {alignment}") fout.write('{} {}\n'.format(key[0], alignment)) # 3. gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - logger.info("align tokens", key[0], align_segs) + logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] subsample = utility.get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) # write tier - align_output_path = os.path.join( - os.path.dirname(self.args.result_file), "align") - tier_path = os.path.join(align_output_path, key[0] + ".tier") - with open(tier_path, 'w') as f: + align_output_path = Path(self.args.result_file).parent / "align" + align_output_path.mkdir(parents=True, exist_ok=True) + tier_path = align_output_path / (key[0] + ".tier") + with tier_path.open('w') as f: f.writelines(tierformat) # write textgrid - textgrid_path = os.path.join(align_output_path, - key[0] + ".TextGrid") + textgrid_path = align_output_path / (key[0] + ".TextGrid") second_per_frame = 1. / (1000. / stride_ms) # 25ms window, 10ms stride second_per_example = ( @@ -572,7 +588,7 @@ class U2Tester(U2Trainer): text_grid.generate_textgrid( maxtime=second_per_example, intervals=tierformat, - output=textgrid_path) + output=str(textgrid_path)) def run_align(self): self.resume_or_scratch() @@ -623,7 +639,7 @@ class U2Tester(U2Trainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() diff --git a/deepspeech/exps/u2_st/bin/train.py b/deepspeech/exps/u2_st/bin/train.py index 86a0f0000..1e6a746b8 100644 --- a/deepspeech/exps/u2_st/bin/train.py +++ b/deepspeech/exps/u2_st/bin/train.py @@ -30,7 +30,7 @@ def main_sp(config, args): def main(config, args): - if args.device == "gpu" and args.nprocs > 1: + if args.nprocs > 0: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index 5734e15f5..e4e70292c 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -17,9 +17,11 @@ import os import sys import time from collections import defaultdict +from contextlib import nullcontext from pathlib import Path from typing import Optional +import jsonlines import numpy as np import paddle from paddle import distributed as dist @@ -37,6 +39,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler from deepspeech.models.u2_st import U2STModel from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR +from deepspeech.training.timer import Timer from deepspeech.training.trainer import Trainer from deepspeech.utils import bleu_score from deepspeech.utils import ctc_utils @@ -45,6 +48,7 @@ from deepspeech.utils import mp_tools from deepspeech.utils import text_grid from deepspeech.utils import utility from deepspeech.utils.log import Log +from deepspeech.utils.utility import UpdateConfig logger = Log(__name__).getlog() @@ -83,6 +87,7 @@ class U2STTrainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() + # forward utt, audio, audio_len, text, text_len = batch_data if isinstance(text, list) and isinstance(text_len, list): # joint training with ASR. Two decoding texts [translation, transcription] @@ -94,18 +99,30 @@ class U2STTrainer(Trainer): else: loss, st_loss, attention_loss, ctc_loss = self.model( audio, audio_len, text, text_len) + # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad - loss.backward() - layer_tools.print_grads(self.model, print_func=None) - losses_np = {'loss': float(loss) * train_conf.accum_grad} - losses_np['st_loss'] = float(st_loss) if attention_loss: losses_np['att_loss'] = float(attention_loss) if ctc_loss: losses_np['ctc_loss'] = float(ctc_loss) + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step if (batch_index + 1) % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() @@ -182,46 +199,42 @@ class U2STTrainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init') - - self.lr_scheduler.step(self.iteration) - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -327,10 +340,10 @@ class U2STTrainer(Trainer): def setup_model(self): config = self.config model_conf = config.model - model_conf.defrost() - model_conf.input_dim = self.train_loader.collate_fn.feature_size - model_conf.output_dim = self.train_loader.collate_fn.vocab_size - model_conf.freeze() + with UpdateConfig(model_conf): + model_conf.input_dim = self.train_loader.collate_fn.feature_size + model_conf.output_dim = self.train_loader.collate_fn.vocab_size + model = U2STModel.from_config(model_conf) if self.parallel: @@ -467,8 +480,10 @@ class U2STTester(U2STTrainer): len_refs += len(target.split()) num_ins += 1 if fout: - fout.write(utt + " " + result + "\n") - logger.info("\nReference: %s\nHypothesis: %s" % (target, result)) + fout.write({"utt": utt, "ref": target, "hyp": result}) + logger.info(f"Utt: {utt}") + logger.info(f"Ref: {target}") + logger.info(f"Hyp: {result}") logger.info("One example BLEU = %s" % (bleu_func([result], [[target]]).prec_str)) @@ -496,7 +511,7 @@ class U2STTester(U2STTrainer): len_refs, num_ins = 0, 0 num_frames = 0.0 num_time = 0.0 - with open(self.args.result_file, 'w') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_translation_metrics( *batch, bleu_func=bleu_func, fout=fout) @@ -569,7 +584,7 @@ class U2STTester(U2STTrainer): # 1. Encoder encoder_out, encoder_mask = self.model._forward_encoder( feat, feats_length) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) @@ -577,26 +592,25 @@ class U2STTester(U2STTrainer): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = ctc_utils.forced_align(ctc_probs, target) - logger.info("align ids", key[0], alignment) + logger.info(f"align ids: {key[0]} {alignment}") fout.write('{} {}\n'.format(key[0], alignment)) # 3. gen praat # segment alignment align_segs = text_grid.segment_alignment(alignment) - logger.info("align tokens", key[0], align_segs) + logger.info(f"align tokens: {key[0]}, {align_segs}") # IntervalTier, List["start end token\n"] subsample = utility.get_subsample(self.config) tierformat = text_grid.align_to_tierformat( align_segs, subsample, token_dict) # write tier - align_output_path = os.path.join( - os.path.dirname(self.args.result_file), "align") - tier_path = os.path.join(align_output_path, key[0] + ".tier") - with open(tier_path, 'w') as f: + align_output_path = Path(self.args.result_file).parent / "align" + align_output_path.mkdir(parents=True, exist_ok=True) + tier_path = align_output_path / (key[0] + ".tier") + with tier_path.open('w') as f: f.writelines(tierformat) # write textgrid - textgrid_path = os.path.join(align_output_path, - key[0] + ".TextGrid") + textgrid_path = align_output_path / (key[0] + ".TextGrid") second_per_frame = 1. / (1000. / stride_ms) # 25ms window, 10ms stride second_per_example = ( @@ -604,7 +618,7 @@ class U2STTester(U2STTrainer): text_grid.generate_textgrid( maxtime=second_per_example, intervals=tierformat, - output=textgrid_path) + output=str(textgrid_path)) def run_align(self): self.resume_or_scratch() @@ -650,7 +664,7 @@ class U2STTester(U2STTrainer): def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') self.setup_output_dir() self.setup_checkpointer() diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index e4364f70a..7dc01c40a 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -76,7 +76,7 @@ class TextFeaturizer(): Args: text (str): Text. - + Returns: List[int]: List of token indices. """ @@ -89,7 +89,7 @@ class TextFeaturizer(): def defeaturize(self, idxs): """Convert a list of token indices to text string, - ignore index after eos_id. + ignore index after eos_id. Args: idxs (List[int]): List of token indices. @@ -196,7 +196,12 @@ class TextFeaturizer(): [(idx, token) for (idx, token) in enumerate(vocab_list)]) token2id = dict( [(token, idx) for (idx, token) in enumerate(vocab_list)]) - - unk_id = vocab_list.index(UNK) - eos_id = vocab_list.index(EOS) + if UNK in vocab_list: + unk_id = vocab_list.index(UNK) + else: + unk_id = -1 + if EOS in vocab_list: + eos_id = vocab_list.index(EOS) + else: + eos_id = -1 return token2id, id2token, vocab_list, unk_id, eos_id diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index 73b3a4ba6..6ace4fc6d 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -130,7 +130,8 @@ class FeatureNormalizer(object): def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" - mean, istd = load_cmvn(filepath, filetype='json') + filetype = filepath.split(".")[-1] + mean, istd = load_cmvn(filepath, filetype=filetype) self._mean = np.expand_dims(mean, axis=0) self._istd = np.expand_dims(istd, axis=0) diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index 72dfc98dd..f7e2cb214 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Contains data helper functions.""" -import codecs import json import math from typing import List from typing import Optional from typing import Text +import jsonlines import numpy as np from deepspeech.utils.log import Log @@ -69,19 +69,19 @@ def read_manifest( Args: manifest_path ([type]): Manifest file to load and parse. - max_input_len ([type], optional): maximum output seq length, - in seconds for raw wav, in frame numbers for feature data. + max_input_len ([type], optional): maximum output seq length, + in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, - in seconds for raw wav, in frame numbers for feature data. + min_input_len (float, optional): minimum input seq length, + in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, + max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, + min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): + max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0. - min_output_input_ratio (float, optional): + min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05. Raises: @@ -92,26 +92,22 @@ def read_manifest( """ manifest = [] - for json_line in codecs.open(manifest_path, 'r', 'utf-8'): - try: - json_data = json.loads(json_line) - except Exception as e: - raise IOError("Error reading manifest: %s" % str(e)) - - feat_len = json_data["feat_shape"][ - 0] if 'feat_shape' in json_data else 1.0 - token_len = json_data["token_shape"][ - 0] if 'token_shape' in json_data else 1.0 - conditions = [ - feat_len >= min_input_len, - feat_len <= max_input_len, - token_len >= min_output_len, - token_len <= max_output_len, - token_len / feat_len >= min_output_input_ratio, - token_len / feat_len <= max_output_input_ratio, - ] - if all(conditions): - manifest.append(json_data) + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + feat_len = json_data["feat_shape"][ + 0] if 'feat_shape' in json_data else 1.0 + token_len = json_data["token_shape"][ + 0] if 'token_shape' in json_data else 1.0 + conditions = [ + feat_len >= min_input_len, + feat_len <= max_input_len, + token_len >= min_output_len, + token_len <= max_output_len, + token_len / feat_len >= min_output_input_ratio, + token_len / feat_len <= max_output_input_ratio, + ] + if all(conditions): + manifest.append(json_data) return manifest @@ -131,7 +127,7 @@ def rms_to_dbfs(rms: float): """Root Mean Square to dBFS. https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/ Audio is mix of sine wave, so 1 amp sine wave's Full scale is 0.7071, equal to -3.0103dB. - + dB = dBFS + 3.0103 dBFS = db - 3.0103 e.g. 0 dB = -3.0103 dBFS @@ -146,26 +142,26 @@ def rms_to_dbfs(rms: float): def max_dbfs(sample_data: np.ndarray): - """Peak dBFS based on the maximum energy sample. + """Peak dBFS based on the maximum energy sample. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ # Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization. return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data)))) def mean_dbfs(sample_data): - """Peak dBFS based on the RMS energy. + """Peak dBFS based on the RMS energy. Args: sample_data ([np.ndarray]): float array, [-1, 1]. Returns: - float: dBFS + float: dBFS """ return rms_to_dbfs( math.sqrt(np.mean(np.square(sample_data, dtype=np.float64)))) @@ -185,7 +181,7 @@ def gain_db_to_ratio(gain_db: float): def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103): """Nomalize audio to dBFS. - + Args: sample_data (np.ndarray): input wave samples, [-1, 1]. dbfs (float, optional): target dBFS. Defaults to -3.0103. @@ -284,6 +280,13 @@ def load_cmvn(cmvn_file: str, filetype: str): cmvn = _load_json_cmvn(cmvn_file) elif filetype == "kaldi": cmvn = _load_kaldi_cmvn(cmvn_file) + elif filetype == "npz": + eps = 1e-14 + npzfile = np.load(cmvn_file) + mean = np.squeeze(npzfile["mean"]) + std = np.squeeze(npzfile["std"]) + istd = 1 / (std + eps) + cmvn = [mean, istd] else: raise ValueError(f"cmvn file type no support: {filetype}") return cmvn[0], cmvn[1] diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index df3004790..15b89ab9f 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -292,10 +292,6 @@ class SpeechCollator(): olens = np.array(text_lens).astype(np.int64) return utts, xs_pad, ilens, ys_pad, olens - @property - def manifest(self): - return self._manifest - @property def vocab_size(self): return self._speech_featurizer.vocab_size diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py index a35a0bc09..310f5f581 100644 --- a/deepspeech/io/dataloader.py +++ b/deepspeech/io/dataloader.py @@ -44,7 +44,7 @@ def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]], def batch_collate(x): - """de-tuple. + """de-minibatch, since user compose batch. Args: x (List[Tuple]): [(utts, xs, ilens, ys, olens)] diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index d1fe04707..56e534756 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -76,19 +76,19 @@ class ManifestDataset(Dataset): Args: manifest_path (str): manifest josn file path - max_input_len ([type], optional): maximum output seq length, + max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). - min_input_len (float, optional): minimum input seq length, + min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. - max_output_len (float, optional): maximum input seq length, + max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0. - min_output_len (float, optional): minimum input seq length, + min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0. - max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. + max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0. min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05. - + """ super().__init__() @@ -147,3 +147,131 @@ class TransformDataset(Dataset): def __getitem__(self, idx): """[] operator.""" return self.converter([self.reader(self.data[idx], return_uttid=True)]) + + +class AudioDataset(Dataset): + def __init__(self, + data_file, + max_length=10240, + min_length=0, + token_max_length=200, + token_min_length=1, + batch_type='static', + batch_size=1, + max_frames_in_batch=0, + sort=True, + raw_wav=True, + stride_ms=10): + """Dataset for loading audio data. + Attributes:: + data_file: input data file + Plain text data file, each line contains following 7 fields, + which is split by '\t': + utt:utt1 + feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30 + feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames) + text:i love you + token: i l o v e y o u + tokenid: int id of this token + token_shape: M,N # M is the number of token, N is vocab size + max_length: drop utterance which is greater than max_length(10ms), unit 10ms. + min_length: drop utterance which is less than min_length(10ms), unit 10ms. + token_max_length: drop utterance which is greater than token_max_length, + especially when use char unit for english modeling + token_min_length: drop utterance which is less than token_max_length + batch_type: static or dynamic, see max_frames_in_batch(dynamic) + batch_size: number of utterances in a batch, + it's for static batch size. + max_frames_in_batch: max feature frames in a batch, + when batch_type is dynamic, it's for dynamic batch size. + Then batch_size is ignored, we will keep filling the + batch until the total frames in batch up to max_frames_in_batch. + sort: whether to sort all data, so the utterance with the same + length could be filled in a same batch. + raw_wav: use raw wave or extracted featute. + if raw wave is used, dynamic waveform-level augmentation could be used + and the feature is extracted by torchaudio. + if extracted featute(e.g. by kaldi) is used, only feature-level + augmentation such as specaug could be used. + """ + assert batch_type in ['static', 'dynamic'] + # read manifest + data = read_manifest(data_file) + if sort: + data = sorted(data, key=lambda x: x["feat_shape"][0]) + if raw_wav: + assert data[0]['feat'].split(':')[0].splitext()[-1] not in ('.ark', + '.scp') + data = map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms)) + + self.input_dim = data[0]['feat_shape'][1] + self.output_dim = data[0]['token_shape'][1] + + # with open(data_file, 'r') as f: + # for line in f: + # arr = line.strip().split('\t') + # if len(arr) != 7: + # continue + # key = arr[0].split(':')[1] + # tokenid = arr[5].split(':')[1] + # output_dim = int(arr[6].split(':')[1].split(',')[1]) + # if raw_wav: + # wav_path = ':'.join(arr[1].split(':')[1:]) + # duration = int(float(arr[2].split(':')[1]) * 1000 / 10) + # data.append((key, wav_path, duration, tokenid)) + # else: + # feat_ark = ':'.join(arr[1].split(':')[1:]) + # feat_info = arr[2].split(':')[1].split(',') + # feat_dim = int(feat_info[1].strip()) + # num_frames = int(feat_info[0].strip()) + # data.append((key, feat_ark, num_frames, tokenid)) + # self.input_dim = feat_dim + # self.output_dim = output_dim + + valid_data = [] + for i in range(len(data)): + length = data[i]['feat_shape'][0] + token_length = data[i]['token_shape'][0] + # remove too lang or too short utt for both input and output + # to prevent from out of memory + if length > max_length or length < min_length: + # logging.warn('ignore utterance {} feature {}'.format( + # data[i][0], length)) + pass + elif token_length > token_max_length or token_length < token_min_length: + pass + else: + valid_data.append(data[i]) + data = valid_data + + self.minibatch = [] + num_data = len(data) + # Dynamic batch size + if batch_type == 'dynamic': + assert (max_frames_in_batch > 0) + self.minibatch.append([]) + num_frames_in_batch = 0 + for i in range(num_data): + length = data[i]['feat_shape'][0] + num_frames_in_batch += length + if num_frames_in_batch > max_frames_in_batch: + self.minibatch.append([]) + num_frames_in_batch = length + self.minibatch[-1].append(data[i]) + # Static batch size + else: + cur = 0 + while cur < num_data: + end = min(cur + batch_size, num_data) + item = [] + for i in range(cur, end): + item.append(data[i]) + self.minibatch.append(item) + cur = end + + def __len__(self): + return len(self.minibatch) + + def __getitem__(self, idx): + instance = self.minibatch[idx] + return instance["utt"], instance["feat"], instance["text"] diff --git a/deepspeech/models/ds2/conv.py b/deepspeech/models/ds2/conv.py index ce962a445..9548af0a2 100644 --- a/deepspeech/models/ds2/conv.py +++ b/deepspeech/models/ds2/conv.py @@ -106,11 +106,9 @@ class ConvBn(nn.Layer): # reset padding part to 0 masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # TODO(Hui Zhang): not support bool multiply - # masks = masks.type_as(x) - masks = masks.astype(x.dtype) - x = x.multiply(masks) - + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks return x, x_len diff --git a/deepspeech/models/ds2/deepspeech2.py b/deepspeech/models/ds2/deepspeech2.py index 5f8f32557..dda26358b 100644 --- a/deepspeech/models/ds2/deepspeech2.py +++ b/deepspeech/models/ds2/deepspeech2.py @@ -128,8 +128,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=3, #Number of stacking RNN layers. rnn_layer_size=1024, #RNN layer size (number of RNN cells). use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. - )) + share_rnn_weights=True, #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + ctc_grad_norm_type='instance', )) if config is not None: config.merge_from_other_cfg(default) return default @@ -141,7 +141,9 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + blank_id=0, + ctc_grad_norm_type='instance'): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -156,10 +158,11 @@ class DeepSpeech2Model(nn.Layer): self.decoder = CTCDecoder( odim=dict_size, # is in vocab enc_n_units=self.encoder.output_size, - blank_id=0, # first token is + blank_id=blank_id, dropout_rate=0.0, reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type=ctc_grad_norm_type) def forward(self, audio, audio_len, text, text_len): """Compute Model loss @@ -221,7 +224,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=config.model.num_rnn_layers, rnn_size=config.model.rnn_layer_size, use_gru=config.model.use_gru, - share_rnn_weights=config.model.share_rnn_weights) + share_rnn_weights=config.model.share_rnn_weights, + blank_id=config.model.blank_id) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -246,7 +250,8 @@ class DeepSpeech2Model(nn.Layer): num_rnn_layers=config.num_rnn_layers, rnn_size=config.rnn_layer_size, use_gru=config.use_gru, - share_rnn_weights=config.share_rnn_weights) + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id) return model @@ -258,7 +263,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): num_rnn_layers=3, rnn_size=1024, use_gru=False, - share_rnn_weights=True): + share_rnn_weights=True, + blank_id=0): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -266,7 +272,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): num_rnn_layers=num_rnn_layers, rnn_size=rnn_size, use_gru=use_gru, - share_rnn_weights=share_rnn_weights) + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) def forward(self, audio, audio_len): """export model function diff --git a/deepspeech/models/ds2/rnn.py b/deepspeech/models/ds2/rnn.py index 3ff91d0af..3fc52a378 100644 --- a/deepspeech/models/ds2/rnn.py +++ b/deepspeech/models/ds2/rnn.py @@ -308,7 +308,8 @@ class RNNStack(nn.Layer): x, x_len = rnn(x, x_len) masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks + return x, x_len diff --git a/deepspeech/models/ds2_online/deepspeech2.py b/deepspeech/models/ds2_online/deepspeech2.py index f597a5783..29d207c44 100644 --- a/deepspeech/models/ds2_online/deepspeech2.py +++ b/deepspeech/models/ds2_online/deepspeech2.py @@ -254,6 +254,7 @@ class DeepSpeech2ModelOnline(nn.Layer): num_fc_layers=2, fc_layers_size_list=[512, 256], use_gru=True, #Use gru if set True. Use simple rnn if set False. + blank_id=0, # index of blank in vocob.txt )) if config is not None: config.merge_from_other_cfg(default) @@ -268,7 +269,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=False): + use_gru=False, + blank_id=0): super().__init__() self.encoder = CRNNEncoder( feat_size=feat_size, @@ -284,10 +286,11 @@ class DeepSpeech2ModelOnline(nn.Layer): self.decoder = CTCDecoder( odim=dict_size, # is in vocab enc_n_units=self.encoder.output_size, - blank_id=0, # first token is + blank_id=blank_id, dropout_rate=0.0, reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type='instance') def forward(self, audio, audio_len, text, text_len): """Compute Model loss @@ -353,7 +356,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction=config.model.rnn_direction, num_fc_layers=config.model.num_fc_layers, fc_layers_size_list=config.model.fc_layers_size_list, - use_gru=config.model.use_gru) + use_gru=config.model.use_gru, + blank_id=config.model.blank_id) infos = Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") @@ -380,7 +384,8 @@ class DeepSpeech2ModelOnline(nn.Layer): rnn_direction=config.rnn_direction, num_fc_layers=config.num_fc_layers, fc_layers_size_list=config.fc_layers_size_list, - use_gru=config.use_gru) + use_gru=config.use_gru, + blank_id=config.blank_id) return model @@ -394,7 +399,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): rnn_direction='forward', num_fc_layers=2, fc_layers_size_list=[512, 256], - use_gru=False): + use_gru=False, + blank_id=0): super().__init__( feat_size=feat_size, dict_size=dict_size, @@ -404,7 +410,8 @@ class DeepSpeech2InferModelOnline(DeepSpeech2ModelOnline): rnn_direction=rnn_direction, num_fc_layers=num_fc_layers, fc_layers_size_list=fc_layers_size_list, - use_gru=use_gru) + use_gru=use_gru, + blank_id=blank_id) def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box): diff --git a/deepspeech/models/u2/__init__.py b/deepspeech/models/u2/__init__.py new file mode 100644 index 000000000..a9010f1d0 --- /dev/null +++ b/deepspeech/models/u2/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .u2 import U2InferModel +from .u2 import U2Model +from .updater import U2Evaluator +from .updater import U2Updater + +__all__ = ["U2Model", "U2InferModel", "U2Evaluator", "U2Updater"] diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2/u2.py similarity index 96% rename from deepspeech/models/u2.py rename to deepspeech/models/u2/u2.py index c1a35560a..46bbd102f 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2/u2.py @@ -48,6 +48,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import pad_sequence from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.utility import log_add +from deepspeech.utils.utility import UpdateConfig __all__ = ["U2Model", "U2InferModel"] @@ -115,7 +116,8 @@ class U2BaseModel(nn.Layer): ctc_weight: float=0.5, ignore_id: int=IGNORE_ID, lsm_weight: float=0.0, - length_normalized_loss: bool=False): + length_normalized_loss: bool=False, + **kwargs): assert 0.0 <= ctc_weight <= 1.0, ctc_weight super().__init__() @@ -162,10 +164,7 @@ class U2BaseModel(nn.Layer): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - #TODO(Hui Zhang): sum not support bool type - #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( - 1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] # 2a. Attention-decoder branch loss_att = None @@ -299,8 +298,8 @@ class U2BaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) - encoder_dim = encoder_out.size(2) + maxlen = encoder_out.shape[1] + encoder_dim = encoder_out.shape[2] running_size = batch_size * beam_size encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) @@ -320,8 +319,7 @@ class U2BaseModel(nn.Layer): # 2. Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - # TODO(Hui Zhang): if end_flag.sum() == running_size: - if end_flag.cast(paddle.int64).sum() == running_size: + if end_flag.sum() == running_size: break # 2.1 Forward decoder step @@ -406,10 +404,8 @@ class U2BaseModel(nn.Layer): encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) - maxlen = encoder_out.size(1) - # (TODO Hui Zhang): bool no support reduce_sum - # encoder_out_lens = encoder_mask.squeeze(1).sum(1) - encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1) + maxlen = encoder_out.shape[1] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size) topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1) @@ -459,7 +455,7 @@ class U2BaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) + maxlen = encoder_out.shape[1] ctc_probs = self.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) @@ -587,7 +583,7 @@ class U2BaseModel(nn.Layer): encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool) + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens) # (beam_size, max_hyps_len, vocab_size) @@ -667,9 +663,7 @@ class U2BaseModel(nn.Layer): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - # @jit.to_static([ - # paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D] - # ]) + # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc @@ -696,13 +690,13 @@ class U2BaseModel(nn.Layer): Returns: paddle.Tensor: decoder output, (B, L) """ - assert encoder_out.size(0) == 1 - num_hyps = hyps.size(0) - assert hyps_lens.size(0) == num_hyps + assert encoder_out.shape[0] == 1 + num_hyps = hyps.shape[0] + assert hyps_lens.shape[0] == num_hyps encoder_out = encoder_out.repeat(num_hyps, 1, 1) # (B, 1, T) encoder_mask = paddle.ones( - [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) + [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool) # (num_hyps, max_hyps_len, vocab_size) decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, hyps_lens) @@ -757,7 +751,7 @@ class U2BaseModel(nn.Layer): Returns: List[List[int]]: transcripts. """ - batch_size = feats.size(0) + batch_size = feats.shape[0] if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: logger.fatal( @@ -785,7 +779,7 @@ class U2BaseModel(nn.Layer): # result in List[int], change it to List[List[int]] for compatible # with other batch decoding mode elif decoding_method == 'ctc_prefix_beam_search': - assert feats.size(0) == 1 + assert feats.shape[0] == 1 hyp = self.ctc_prefix_beam_search( feats, feats_lengths, @@ -795,7 +789,7 @@ class U2BaseModel(nn.Layer): simulate_streaming=simulate_streaming) hyps = [hyp] elif decoding_method == 'attention_rescoring': - assert feats.size(0) == 1 + assert feats.shape[0] == 1 hyp = self.attention_rescoring( feats, feats_lengths, @@ -836,6 +830,7 @@ class U2Model(U2BaseModel): Returns: int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc """ + # cmvn if configs['cmvn_file'] is not None: mean, istd = load_cmvn(configs['cmvn_file'], configs['cmvn_file_type']) @@ -845,11 +840,13 @@ class U2Model(U2BaseModel): else: global_cmvn = None + # input & output dim input_dim = configs['input_dim'] vocab_size = configs['output_dim'] assert input_dim != 0, input_dim assert vocab_size != 0, vocab_size + # encoder encoder_type = configs.get('encoder', 'transformer') logger.info(f"U2 Encoder type: {encoder_type}") if encoder_type == 'transformer': @@ -861,16 +858,21 @@ class U2Model(U2BaseModel): else: raise ValueError(f"not support encoder type:{encoder_type}") + # decoder decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + + # ctc decoder and ctc loss + model_conf = configs['model_conf'] ctc = CTCDecoder( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, - dropout_rate=0.0, + dropout_rate=model_conf['ctc_dropoutrate'], reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type=model_conf['ctc_grad_norm_type']) return vocab_size, encoder, decoder, ctc @@ -902,10 +904,10 @@ class U2Model(U2BaseModel): Returns: DeepSpeech2Model: The model built from pretrained result. """ - config.defrost() - config.input_dim = dataloader.collate_fn.feature_size - config.output_dim = dataloader.collate_fn.vocab_size - config.freeze() + with UpdateConfig(config): + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + model = cls.from_config(config) if checkpoint_path: diff --git a/deepspeech/models/u2/updater.py b/deepspeech/models/u2/updater.py new file mode 100644 index 000000000..7b70ca047 --- /dev/null +++ b/deepspeech/models/u2/updater.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import nullcontext + +import paddle +from paddle import distributed as dist + +from deepspeech.training.extensions.evaluator import StandardEvaluator +from deepspeech.training.reporter import report +from deepspeech.training.timer import Timer +from deepspeech.training.updaters.standard_updater import StandardUpdater +from deepspeech.utils import layer_tools +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +class U2Evaluator(StandardEvaluator): + def __init__(self, model, dataloader): + super().__init__(model, dataloader) + self.msg = "" + self.num_seen_utts = 0 + self.total_loss = 0.0 + + def evaluate_core(self, batch): + self.msg = "Valid: Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + loss, attention_loss, ctc_loss = self.model(*batch[1:]) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + self.num_seen_utts += num_utts + self.total_loss += float(loss) * num_utts + + losses_dict['loss'] = float(loss) + if attention_loss: + losses_dict['att_loss'] = float(attention_loss) + if ctc_loss: + losses_dict['ctc_loss'] = float(ctc_loss) + + for k, v in losses_dict.items(): + report("eval/" + k, v) + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + logger.info(self.msg) + return self.total_loss, self.num_seen_utts + + +class U2Updater(StandardUpdater): + def __init__(self, + model, + optimizer, + scheduler, + dataloader, + init_state=None, + accum_grad=1, + **kwargs): + super().__init__( + model, optimizer, scheduler, dataloader, init_state=init_state) + self.accum_grad = accum_grad + self.forward_count = 0 + self.msg = "" + + def update_core(self, batch): + """One Step + + Args: + batch (List[Object]): utts, xs, xlens, ys, ylens + """ + losses_dict = {} + self.msg = "Rank: {}, ".format(dist.get_rank()) + + # forward + batch_size = batch[1].shape[0] + loss, attention_loss, ctc_loss = self.model(*batch[1:]) + # loss div by `batch_size * accum_grad` + loss /= self.accum_grad + + # loss backward + if (self.forward_count + 1) != self.accum_grad: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # loss info + losses_dict['loss'] = float(loss) * self.accum_grad + if attention_loss: + losses_dict['att_loss'] = float(attention_loss) + if ctc_loss: + losses_dict['ctc_loss'] = float(ctc_loss) + # report loss + for k, v in losses_dict.items(): + report("train/" + k, v) + # loss msg + self.msg += "batch size: {}, ".format(batch_size) + self.msg += "accum: {}, ".format(self.accum_grad) + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + + # Truncate the graph + loss.detach() + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + + self.optimizer.step() + self.optimizer.clear_grad() + self.scheduler.step() + + def update(self): + # model is default in train mode + + # training for a step is implemented here + with Timer("data time cost:{}"): + batch = self.read_batch() + with Timer("step time cost:{}"): + self.update_core(batch) + + # #iterations with accum_grad > 1 + # Ref.: https://github.com/espnet/espnet/issues/777 + if self.forward_count == 0: + self.state.iteration += 1 + if self.updates_per_epoch is not None: + if self.state.iteration % self.updates_per_epoch == 0: + self.state.epoch += 1 diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py index b725cc359..a3d99942f 100644 --- a/deepspeech/models/u2_st.py +++ b/deepspeech/models/u2_st.py @@ -42,6 +42,7 @@ from deepspeech.utils import layer_tools from deepspeech.utils.log import Log from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.utility import UpdateConfig __all__ = ["U2STModel", "U2STInferModel"] @@ -163,10 +164,7 @@ class U2STBaseModel(nn.Layer): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) encoder_time = time.time() - start #logger.debug(f"encoder time: {encoder_time}") - #TODO(Hui Zhang): sum not support bool type - #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( - 1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] # 2a. ST-decoder branch start = time.time() @@ -342,8 +340,8 @@ class U2STBaseModel(nn.Layer): speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) # (B, maxlen, encoder_dim) - maxlen = encoder_out.size(1) - encoder_dim = encoder_out.size(2) + maxlen = encoder_out.shape[1] + encoder_dim = encoder_out.shape[2] running_size = batch_size * beam_size encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) @@ -363,8 +361,7 @@ class U2STBaseModel(nn.Layer): # 2. Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - # TODO(Hui Zhang): if end_flag.sum() == running_size: - if end_flag.cast(paddle.int64).sum() == running_size: + if end_flag.sum() == running_size: break # 2.1 Forward decoder step @@ -417,26 +414,26 @@ class U2STBaseModel(nn.Layer): best_hyps = best_hyps[:, 1:] return best_hyps - @jit.to_static + # @jit.to_static def subsampling_rate(self) -> int: """ Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate - @jit.to_static + # @jit.to_static def right_context(self) -> int: """ Export interface for c++ call, return right_context of the model """ return self.encoder.embed.right_context - @jit.to_static + # @jit.to_static def sos_symbol(self) -> int: """ Export interface for c++ call, return sos symbol id of the model """ return self.sos - @jit.to_static + # @jit.to_static def eos_symbol(self) -> int: """ Export interface for c++ call, return eos symbol id of the model """ @@ -472,7 +469,7 @@ class U2STBaseModel(nn.Layer): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - @jit.to_static + # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc @@ -499,13 +496,13 @@ class U2STBaseModel(nn.Layer): Returns: paddle.Tensor: decoder output, (B, L) """ - assert encoder_out.size(0) == 1 - num_hyps = hyps.size(0) - assert hyps_lens.size(0) == num_hyps + assert encoder_out.shape[0] == 1 + num_hyps = hyps.shape[0] + assert hyps_lens.shape[0] == num_hyps encoder_out = encoder_out.repeat(num_hyps, 1, 1) # (B, 1, T) encoder_mask = paddle.ones( - [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) + [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool) # (num_hyps, max_hyps_len, vocab_size) decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, hyps_lens) @@ -560,7 +557,7 @@ class U2STBaseModel(nn.Layer): Returns: List[List[int]]: transcripts. """ - batch_size = feats.size(0) + batch_size = feats.shape[0] if decoding_method == 'fullsentence': hyps = self.translate( @@ -647,13 +644,16 @@ class U2STModel(U2STBaseModel): decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + # ctc decoder and ctc loss + model_conf = configs['model_conf'] ctc = CTCDecoder( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, - dropout_rate=0.0, + dropout_rate=model_conf['ctc_dropout_rate'], reduction=True, # sum - batch_average=True) # sum / batch_size + batch_average=True, # sum / batch_size + grad_norm_type=model_conf['ctc_grad_norm_type']) return vocab_size, encoder, (st_decoder, decoder, ctc) else: @@ -687,10 +687,10 @@ class U2STModel(U2STBaseModel): Returns: DeepSpeech2Model: The model built from pretrained result. """ - config.defrost() - config.input_dim = dataloader.collate_fn.feature_size - config.output_dim = dataloader.collate_fn.vocab_size - config.freeze() + with UpdateConfig(config): + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + model = cls.from_config(config) if checkpoint_path: diff --git a/deepspeech/modules/activation.py b/deepspeech/modules/activation.py index 30132775e..3cb8729e1 100644 --- a/deepspeech/modules/activation.py +++ b/deepspeech/modules/activation.py @@ -15,12 +15,13 @@ from collections import OrderedDict import paddle from paddle import nn +from paddle.nn import functional as F from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"] +__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock", "GLU"] def brelu(x, t_min=0.0, t_max=24.0, name=None): @@ -30,6 +31,17 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None): return x.maximum(t_min).minimum(t_max) +class GLU(nn.Layer): + """Gated Linear Units (GLU) Layer""" + + def __init__(self, dim: int=-1): + super().__init__() + self.dim = dim + + def forward(self, xs): + return F.glu(xs, axis=self.dim) + + class LinearGLUBlock(nn.Layer): """A linear Gated Linear Units (GLU) block.""" @@ -133,13 +145,18 @@ def get_activation(act): """Return activation function.""" # Lazy load to avoid unused import activation_funcs = { + "hardshrink": paddle.nn.Hardshrink, + "hardswish": paddle.nn.Hardswish, "hardtanh": paddle.nn.Hardtanh, "tanh": paddle.nn.Tanh, "relu": paddle.nn.ReLU, + "relu6": paddle.nn.ReLU6, + "leakyrelu": paddle.nn.LeakyReLU, "selu": paddle.nn.SELU, "swish": paddle.nn.Swish, "gelu": paddle.nn.GELU, - "brelu": brelu, + "glu": GLU, + "elu": paddle.nn.ELU, } return activation_funcs[act]() diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py index 4401a4a55..f94797282 100644 --- a/deepspeech/modules/attention.py +++ b/deepspeech/modules/attention.py @@ -70,7 +70,7 @@ class MultiHeadedAttention(nn.Layer): paddle.Tensor: Transformed value tensor, size (#batch, n_head, time2, d_k). """ - n_batch = query.size(0) + n_batch = query.shape[0] q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) @@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer): paddle.Tensor: Transformed value weighted by the attention score, (#batch, time1, d_model). """ - n_batch = value.size(0) + n_batch = value.shape[0] if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) scores = scores.masked_fill(mask, -float('inf')) @@ -109,8 +109,8 @@ class MultiHeadedAttention(nn.Layer): p_attn = self.dropout(attn) x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k) - x = x.transpose([0, 2, 1, 3]).contiguous().view( - n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h * + self.d_k) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model) @@ -172,15 +172,16 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): paddle.Tensor: Output tensor. (batch, head, time1, time1) """ zero_pad = paddle.zeros( - (x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype) + (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) x_padded = paddle.cat([zero_pad, x], dim=-1) - x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) + x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, + x.shape[2]) x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] if zero_triu: - ones = paddle.ones((x.size(2), x.size(3))) - x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + ones = paddle.ones((x.shape[2], x.shape[3])) + x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :] return x @@ -205,7 +206,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): q, k, v = self.forward_qkv(query, key, value) q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k) - n_batch_pos = pos_emb.size(0) + n_batch_pos = pos_emb.shape[0] p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k) diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py index 8bf48b2c8..22a168800 100644 --- a/deepspeech/modules/conv.py +++ b/deepspeech/modules/conv.py @@ -113,11 +113,9 @@ class ConvBn(nn.Layer): # reset padding part to 0 masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - # TODO(Hui Zhang): not support bool multiply - # masks = masks.type_as(x) - masks = masks.astype(x.dtype) - x = x.multiply(masks) - + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks return x, x_len diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 31e489a3d..b3ca28279 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -16,15 +16,19 @@ from paddle import nn from paddle.nn import functional as F from typeguard import check_argument_types -from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch -from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder -from deepspeech.decoders.swig_wrapper import Scorer from deepspeech.modules.loss import CTCLoss from deepspeech.utils import ctc_utils from deepspeech.utils.log import Log logger = Log(__name__).getlog() +try: + from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 + from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder # noqa: F401 + from deepspeech.decoders.swig_wrapper import Scorer # noqa: F401 +except Exception as e: + logger.info("ctcdecoder not installed!") + __all__ = ['CTCDecoder'] @@ -35,7 +39,8 @@ class CTCDecoder(nn.Layer): blank_id=0, dropout_rate: float=0.0, reduction: bool=True, - batch_average: bool=True): + batch_average: bool=True, + grad_norm_type: str="instance"): """CTC decoder Args: @@ -44,6 +49,7 @@ class CTCDecoder(nn.Layer): dropout_rate (float): dropout rate (0.0 ~ 1.0) reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' batch_average (bool): do batch dim wise average. + grad_norm_type (str): one of 'instance', 'batchsize', 'frame', None. """ assert check_argument_types() super().__init__() @@ -56,7 +62,8 @@ class CTCDecoder(nn.Layer): self.criterion = CTCLoss( blank=self.blank_id, reduction=reduction_type, - batch_average=batch_average) + batch_average=batch_average, + grad_norm_type=grad_norm_type) # CTCDecoder LM Score handle self._ext_scorer = None @@ -132,7 +139,7 @@ class CTCDecoder(nn.Layer): results = [] for i, probs in enumerate(probs_split): output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) + probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) results.append(output_transcription) return results @@ -212,13 +219,15 @@ class CTCDecoder(nn.Layer): num_processes=num_processes, ext_scoring_func=self._ext_scorer, cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) + cutoff_top_n=cutoff_top_n, + blank_id=self.blank_id) results = [result[0][1] for result in beam_search_results] return results def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, decoding_method): + if decoding_method == "ctc_beam_search": self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, vocab_list) @@ -229,7 +238,7 @@ class CTCDecoder(nn.Layer): """ctc decoding with probs. Args: - probs (Tenosr): activation after softmax + probs (Tenosr): activation after softmax logits_lens (Tenosr): audio output lens vocab_list ([type]): [description] decoding_method ([type]): [description] diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index 87c9fa492..8ca72894a 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -122,11 +122,9 @@ class TransformerDecoder(nn.Layer): # tgt_mask: (B, 1, L) tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1)) # m: (1, L, L) - m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0) + m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0) # tgt_mask: (B, L, L) - # TODO(Hui Zhang): not support & for tensor - # tgt_mask = tgt_mask & m - tgt_mask = tgt_mask.logical_and(m) + tgt_mask = tgt_mask & m x, _ = self.embed(tgt) for layer in self.decoders: @@ -137,9 +135,7 @@ class TransformerDecoder(nn.Layer): if self.use_output_layer: x = self.output_layer(x) - # TODO(Hui Zhang): reduce_sum not support bool type - # olens = tgt_mask.sum(1) - olens = tgt_mask.astype(paddle.int).sum(1) + olens = tgt_mask.sum(1) return x, olens def forward_one_step( diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index 98b4e1291..fbbda023c 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -68,7 +68,7 @@ class PositionalEncoding(nn.Layer): paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...) """ T = x.shape[1] - assert offset + x.size(1) < self.max_len + assert offset + x.shape[1] < self.max_len #TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + T] x = x * self.xscale + pos_emb @@ -114,7 +114,7 @@ class RelPositionalEncoding(PositionalEncoding): paddle.Tensor: Encoded tensor (batch, time, `*`). paddle.Tensor: Positional embedding tensor (1, time, `*`). """ - assert offset + x.size(1) < self.max_len + assert offset + x.shape[1] < self.max_len x = x * self.xscale #TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor pos_emb = self.pe[:, offset:offset + x.shape[1]] diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index 71ec61a0e..d4a8275c3 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -159,11 +159,10 @@ class BaseEncoder(nn.Layer): if self.global_cmvn is not None: xs = self.global_cmvn(xs) #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor - xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0) + xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor masks = masks.astype(paddle.bool) - #TODO(Hui Zhang): mask_pad = ~masks - mask_pad = masks.logical_not() + mask_pad = ~masks chunk_masks = add_optional_chunk_mask( xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk, decoding_chunk_size, self.static_chunk_size, @@ -207,11 +206,11 @@ class BaseEncoder(nn.Layer): chunk computation List[paddle.Tensor]: conformer cnn cache """ - assert xs.size(0) == 1 # batch size must be one + assert xs.shape[0] == 1 # batch size must be one # tmp_masks is just for interface compatibility # TODO(Hui Zhang): stride_slice not support bool tensor # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) - tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32) + tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32) tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T] if self.global_cmvn is not None: @@ -221,25 +220,25 @@ class BaseEncoder(nn.Layer): xs, tmp_masks, offset=offset) #xs=(B, T, D), pos_emb=(B=1, T, D) if subsampling_cache is not None: - cache_size = subsampling_cache.size(1) #T + cache_size = subsampling_cache.shape[1] #T xs = paddle.cat((subsampling_cache, xs), dim=1) else: cache_size = 0 # only used when using `RelPositionMultiHeadedAttention` pos_emb = self.embed.position_encoding( - offset=offset - cache_size, size=xs.size(1)) + offset=offset - cache_size, size=xs.shape[1]) if required_cache_size < 0: next_cache_start = 0 elif required_cache_size == 0: - next_cache_start = xs.size(1) + next_cache_start = xs.shape[1] else: - next_cache_start = xs.size(1) - required_cache_size + next_cache_start = xs.shape[1] - required_cache_size r_subsampling_cache = xs[:, next_cache_start:, :] # Real mask for transformer/conformer layers - masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool) + masks = paddle.ones([1, xs.shape[1]], dtype=paddle.bool) masks = masks.unsqueeze(1) #[B=1, L'=1, T] r_elayers_output_cache = [] r_conformer_cnn_cache = [] @@ -303,7 +302,7 @@ class BaseEncoder(nn.Layer): stride = subsampling * decoding_chunk_size decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.size(1) + num_frames = xs.shape[1] required_cache_size = decoding_chunk_size * num_decoding_left_chunks subsampling_cache: Optional[paddle.Tensor] = None elayers_output_cache: Optional[List[paddle.Tensor]] = None @@ -319,10 +318,10 @@ class BaseEncoder(nn.Layer): chunk_xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) outputs.append(y) - offset += y.size(1) + offset += y.shape[1] ys = paddle.cat(outputs, 1) # fake mask, just for jit script and compatibility with `forward` api - masks = paddle.ones([1, ys.size(1)], dtype=paddle.bool) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) masks = masks.unsqueeze(1) return ys, masks diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index 8918ca669..2c58be7e3 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -23,11 +23,32 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] class CTCLoss(nn.Layer): - def __init__(self, blank=0, reduction='sum', batch_average=False): + def __init__(self, + blank=0, + reduction='sum', + batch_average=False, + grad_norm_type=None): super().__init__() # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) self.batch_average = batch_average + logger.info( + f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}") + + # instance for norm_by_times + # batch for norm_by_batchsize + # frame for norm_by_total_logits_len + assert grad_norm_type in ('instance', 'batch', 'frame', None) + self.norm_by_times = False + self.norm_by_batchsize = False + self.norm_by_total_logits_len = False + logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}") + if grad_norm_type == 'instance': + self.norm_by_times = True + if grad_norm_type == 'batch': + self.norm_by_batchsize = True + if grad_norm_type == 'frame': + self.norm_by_total_logits_len = True def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. @@ -46,10 +67,15 @@ class CTCLoss(nn.Layer): # warp-ctc need activation with shape [T, B, V + 1] # logits: (B, L, D) -> (L, B, D) logits = logits.transpose([1, 0, 2]) - # (TODO:Hui Zhang) ctc loss does not support int64 labels ys_pad = ys_pad.astype(paddle.int32) loss = self.loss( - logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average) + logits, + ys_pad, + hlens, + ys_lens, + norm_by_times=self.norm_by_times, + norm_by_batchsize=self.norm_by_batchsize, + norm_by_total_logits_len=self.norm_by_total_logits_len) if self.batch_average: # Batch-size average loss = loss / B @@ -124,9 +150,9 @@ class LabelSmoothingLoss(nn.Layer): # use zeros_like instead of torch.no_grad() for true_dist, # since no_grad() can not be exported by JIT true_dist = paddle.full_like(x, self.smoothing / (self.size - 1)) - ignore = target == self.padding_idx # (B,) + ignore = (target == self.padding_idx) # (B,) - # target = target * (1 - ignore) # avoid -1 index + #TODO(Hui Zhang): target = target * (1 - ignore) # avoid -1 index target = target.masked_fill(ignore, 0) # avoid -1 index # true_dist.scatter_(1, target.unsqueeze(1), self.confidence) target_mask = F.one_hot(target, self.size) @@ -135,10 +161,8 @@ class LabelSmoothingLoss(nn.Layer): kl = self.criterion(F.log_softmax(x, axis=1), true_dist) - #TODO(Hui Zhang): sum not support bool type - #total = len(target) - int(ignore.sum()) - total = len(target) - int(ignore.type_as(target).sum()) + total = len(target) - int(ignore.sum()) denom = total if self.normalize_length else B - #numer = (kl * (1 - ignore)).sum() + #TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum() numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum() return numer / denom diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index 05e86eb33..6d46f5ba0 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -69,8 +69,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]] """ - #TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~ - return make_pad_mask(lengths).logical_not() + return ~make_pad_mask(lengths) def subsequent_mask(size: int) -> paddle.Tensor: @@ -92,12 +91,7 @@ def subsequent_mask(size: int) -> paddle.Tensor: [1, 1, 1]] """ ret = paddle.ones([size, size], dtype=paddle.bool) - #TODO(Hui Zhang): tril not support bool - #return paddle.tril(ret) - ret = ret.astype(paddle.float) - ret = paddle.tril(ret) - ret = ret.astype(paddle.bool) - return ret + return paddle.tril(ret) def subsequent_chunk_mask( @@ -186,15 +180,13 @@ def add_optional_chunk_mask(xs: paddle.Tensor, chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, num_left_chunks) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - # chunk_masks = masks & chunk_masks # (B, L, L) - chunk_masks = masks.logical_and(chunk_masks) # (B, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) elif static_chunk_size > 0: num_left_chunks = num_decoding_left_chunks chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, num_left_chunks) # (L, L) chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - # chunk_masks = masks & chunk_masks # (B, L, L) - chunk_masks = masks.logical_and(chunk_masks) # (B, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) else: chunk_masks = masks return chunk_masks diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py index 0d8c9fd2c..8f8b2a18d 100644 --- a/deepspeech/modules/rnn.py +++ b/deepspeech/modules/rnn.py @@ -308,7 +308,7 @@ class RNNStack(nn.Layer): x, x_len = rnn(x, x_len) masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] - # TODO(Hui Zhang): not support bool multiply - masks = masks.astype(x.dtype) - x = x.multiply(masks) + # https://github.com/PaddlePaddle/Paddle/pull/29265 + # rhs will type promote to lhs + x = x * masks return x, x_len diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index 7f4bb8048..e079293c7 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -14,6 +14,20 @@ import argparse +class ExtendAction(argparse.Action): + """ + [Since Python 3.8, the "extend" is available directly in stdlib] + (https://docs.python.org/3.8/library/argparse.html#action). + If you only have to support 3.8+ then defining it yourself is no longer required. + Usage of stdlib "extend" action is exactly the same way as this answer originally described: + """ + + def __call__(self, parser, namespace, values, option_string=None): + items = getattr(namespace, self.dest) or [] + items.extend(values) + setattr(namespace, self.dest, items) + + def default_argument_parser(): r"""A simple yet genral argument parser for experiments with parakeet. @@ -30,7 +44,7 @@ def default_argument_parser(): The ``--checkpoint_path`` specifies the checkpoint to load from. - The ``--device`` and ``--nprocs`` specifies how to run the training. + The ``--nprocs`` specifies how to run the training. See Also @@ -42,29 +56,53 @@ def default_argument_parser(): the parser """ parser = argparse.ArgumentParser() + parser.register('action', 'extend', ExtendAction) - # yapf: disable - # data and output - parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") - parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.") - parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") - - # load from saved checkpoint - parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") - - # running - parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], - help="device type to use, cpu and gpu are supported.") - parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") - - # overwrite extra config and default config - # parser.add_argument("--opts", nargs=argparse.REMAINDER, - # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") - parser.add_argument("--opts", type=str, default=[], nargs='+', - help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + train_group = parser.add_argument_group( + title='Train Options', description=None) + train_group.add_argument( + "--seed", + type=int, + default=None, + help="seed to use for paddle, np and random. None or 0 for random, else set seed." + ) + train_group.add_argument( + "--nprocs", + type=int, + default=1, + help="number of parallel processes. 0 for cpu.") + train_group.add_argument( + "--config", metavar="CONFIG_FILE", help="config file.") + train_group.add_argument( + "--output", metavar="CKPT_DIR", help="path to save checkpoint.") + train_group.add_argument( + "--checkpoint_path", type=str, help="path to load checkpoint") + train_group.add_argument( + "--opts", + action='extend', + nargs=2, + metavar=('key', 'val'), + help="overwrite --config field, passing (KEY VALUE) pairs") + train_group.add_argument( + "--dump-config", metavar="FILE", help="dump config to `this` file.") - parser.add_argument("--seed", type=int, default=None, - help="seed to use for paddle, np and random. None or 0 for random, else set seed.") - # yapd: enable + profile_group = parser.add_argument_group( + title='Benchmark Options', description=None) + profile_group.add_argument( + '--profiler-options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) + profile_group.add_argument( + '--benchmark-batch-size', + type=int, + default=None, + help='batch size for benchmark.') + profile_group.add_argument( + '--benchmark-max-step', + type=int, + default=None, + help='max iteration for benchmark.') return parser diff --git a/deepspeech/training/extensions/evaluator.py b/deepspeech/training/extensions/evaluator.py index 96ff967f5..1026a4ec3 100644 --- a/deepspeech/training/extensions/evaluator.py +++ b/deepspeech/training/extensions/evaluator.py @@ -13,14 +13,18 @@ # limitations under the License. from typing import Dict -import extension import paddle +from paddle import distributed as dist from paddle.io import DataLoader from paddle.nn import Layer +from . import extension from ..reporter import DictSummary +from ..reporter import ObsScope from ..reporter import report -from ..reporter import scope +from ..timer import Timer +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() class StandardEvaluator(extension.Extension): @@ -43,6 +47,27 @@ class StandardEvaluator(extension.Extension): def evaluate_core(self, batch): # compute self.model(batch) # you may report here + return + + def evaluate_sync(self, data): + # dist sync `evaluate_core` outputs + if data is None: + return + + numerator, denominator = data + if dist.get_world_size() > 1: + numerator = paddle.to_tensor(numerator) + denominator = paddle.to_tensor(denominator) + # the default operator in all_reduce function is sum. + dist.all_reduce(numerator) + dist.all_reduce(denominator) + value = numerator / denominator + value = float(value) + else: + value = numerator / denominator + # used for `snapshort` to do kbest save. + report("VALID/LOSS", value) + logger.info(f"Valid: all-reduce loss {value}") def evaluate(self): # switch to eval mode @@ -53,12 +78,16 @@ class StandardEvaluator(extension.Extension): summary = DictSummary() for batch in self.dataloader: observation = {} - with scope(observation): + with ObsScope(observation): # main evaluation computation here. with paddle.no_grad(): - self.evaluate_core(batch) + self.evaluate_sync(self.evaluate_core(batch)) summary.add(observation) summary = summary.compute_mean() + + # switch to train mode + for model in self.models.values(): + model.train() return summary def __call__(self, trainer=None): @@ -66,6 +95,7 @@ class StandardEvaluator(extension.Extension): # if it is used to extend a trainer, the metrics is reported to # to observation of the trainer # or otherwise, you can use your own observation - summary = self.evaluate() + with Timer("Eval Time Cost: {}"): + summary = self.evaluate() for k, v in summary.items(): report(k, v) diff --git a/deepspeech/training/extensions/snapshot.py b/deepspeech/training/extensions/snapshot.py index cb4e6dfbf..e81eb97fc 100644 --- a/deepspeech/training/extensions/snapshot.py +++ b/deepspeech/training/extensions/snapshot.py @@ -20,8 +20,9 @@ from typing import List import jsonlines -from deepspeech.training.extensions import extension -from deepspeech.training.updaters.trainer import Trainer +from . import extension +from ..reporter import get_observations +from ..updaters.trainer import Trainer from deepspeech.utils.log import Log from deepspeech.utils.mp_tools import rank_zero_only @@ -52,8 +53,19 @@ class Snapshot(extension.Extension): priority = -100 default_name = "snapshot" - def __init__(self, max_size: int=5, snapshot_on_error: bool=False): + def __init__(self, + mode='latest', + max_size: int=5, + indicator=None, + less_better=True, + snapshot_on_error: bool=False): self.records: List[Dict[str, Any]] = [] + assert mode in ('latest', 'kbest'), mode + if mode == 'kbest': + assert indicator is not None + self.mode = mode + self.indicator = indicator + self.less_is_better = less_better self.max_size = max_size self._snapshot_on_error = snapshot_on_error self._save_all = (max_size == -1) @@ -66,16 +78,17 @@ class Snapshot(extension.Extension): # load existing records record_path: Path = self.checkpoint_dir / "records.jsonl" if record_path.exists(): - logger.debug("Loading from an existing checkpoint dir") self.records = load_records(record_path) - trainer.updater.load(self.records[-1]['path']) + ckpt_path = self.records[-1]['path'] + logger.info(f"Loading from an existing checkpoint {ckpt_path}") + trainer.updater.load(ckpt_path) def on_error(self, trainer, exc, tb): if self._snapshot_on_error: - self.save_checkpoint_and_update(trainer) + self.save_checkpoint_and_update(trainer, 'latest') def __call__(self, trainer: Trainer): - self.save_checkpoint_and_update(trainer) + self.save_checkpoint_and_update(trainer, self.mode) def full(self): """Whether the number of snapshots it keeps track of is greater @@ -83,12 +96,12 @@ class Snapshot(extension.Extension): return (not self._save_all) and len(self.records) > self.max_size @rank_zero_only - def save_checkpoint_and_update(self, trainer: Trainer): + def save_checkpoint_and_update(self, trainer: Trainer, mode: str): """Saving new snapshot and remove the oldest snapshot if needed.""" iteration = trainer.updater.state.iteration epoch = trainer.updater.state.epoch num = epoch if self.trigger[1] == 'epoch' else iteration - path = self.checkpoint_dir / f"{num}.pdz" + path = self.checkpoint_dir / f"{num}.np" # add the new one trainer.updater.save(path) @@ -97,11 +110,17 @@ class Snapshot(extension.Extension): 'path': str(path.resolve()), # use absolute path 'iteration': iteration, 'epoch': epoch, + 'indicator': get_observations()[self.indicator] } self.records.append(record) # remove the earist if self.full(): + if mode == 'kbest': + self.records = sorted( + self.records, + key=lambda record: record['indicator'], + reverse=not self.less_is_better) eariest_record = self.records[0] os.remove(eariest_record["path"]) self.records.pop(0) diff --git a/deepspeech/training/extensions/visualizer.py b/deepspeech/training/extensions/visualizer.py index b69e94aaf..e5f456cac 100644 --- a/deepspeech/training/extensions/visualizer.py +++ b/deepspeech/training/extensions/visualizer.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from deepspeech.training.extensions import extension -from deepspeech.training.updaters.trainer import Trainer +from visualdl import LogWriter + +from . import extension +from ..updaters.trainer import Trainer class VisualDL(extension.Extension): @@ -26,8 +28,8 @@ class VisualDL(extension.Extension): default_name = 'visualdl' priority = extension.PRIORITY_READER - def __init__(self, writer): - self.writer = writer + def __init__(self, output_dir): + self.writer = LogWriter(str(output_dir)) def __call__(self, trainer: Trainer): for k, v in trainer.observation.items(): diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py index f46814eb0..87b36acae 100644 --- a/deepspeech/training/gradclip.py +++ b/deepspeech/training/gradclip.py @@ -47,7 +47,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): sum_square = layers.reduce_sum(square) sum_square_list.append(sum_square) - # debug log + # debug log, not dump all since slow down train process if i < 10: logger.debug( f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") @@ -76,7 +76,7 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): new_grad = layers.elementwise_mul(x=g, y=clip_var) params_and_grads.append((p, new_grad)) - # debug log + # debug log, not dump all since slow down train process if i < 10: logger.debug( f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" diff --git a/deepspeech/training/reporter.py b/deepspeech/training/reporter.py index 66a81adef..7afc33f38 100644 --- a/deepspeech/training/reporter.py +++ b/deepspeech/training/reporter.py @@ -19,7 +19,7 @@ OBSERVATIONS = None @contextlib.contextmanager -def scope(observations): +def ObsScope(observations): # make `observation` the target to report to. # it is basically a dictionary that stores temporary observations global OBSERVATIONS diff --git a/deepspeech/training/timer.py b/deepspeech/training/timer.py new file mode 100644 index 000000000..2ca9d6386 --- /dev/null +++ b/deepspeech/training/timer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import time + +from deepspeech.utils.log import Log + +__all__ = ["Timer"] + +logger = Log(__name__).getlog() + + +class Timer(): + """To be used like this: + with Timer("Message") as value: + do some thing + """ + + def __init__(self, message=None): + self.message = message + + def duration(self) -> str: + elapsed_time = time.time() - self.start + time_str = str(datetime.timedelta(seconds=elapsed_time)) + return time_str + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, type, value, traceback): + if self.message: + logger.info(self.message.format(self.duration())) + + def __call__(self) -> float: + return time.time() - self.start + + def __str__(self): + return self.duration() diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 3a922c6f4..79b1562e4 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -11,17 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys import time +from collections import OrderedDict from pathlib import Path import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter +from deepspeech.training.reporter import ObsScope +from deepspeech.training.reporter import report +from deepspeech.training.timer import Timer from deepspeech.utils import mp_tools +from deepspeech.utils import profiler from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log from deepspeech.utils.utility import seed_all +from deepspeech.utils.utility import UpdateConfig __all__ = ["Trainer"] @@ -79,7 +86,7 @@ class Trainer(): >>> config.merge_from_list(args.opts) >>> config.freeze() >>> - >>> if args.nprocs > 1 and args.device == "gpu": + >>> if args.nprocs > 0: >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) >>> else: >>> main_sp(config, args) @@ -94,15 +101,25 @@ class Trainer(): self.checkpoint_dir = None self.iteration = 0 self.epoch = 0 + self.rank = dist.get_rank() + + logger.info(f"Rank: {self.rank}/{dist.get_world_size()}") if args.seed: seed_all(args.seed) logger.info(f"Set seed {args.seed}") + if self.args.benchmark_batch_size: + with UpdateConfig(self.config): + self.config.collator.batch_size = self.args.benchmark_batch_size + self.config.training.log_interval = 1 + logger.info( + f"Benchmark reset batch-size: {self.args.benchmark_batch_size}") + def setup(self): """Setup the experiment. """ - paddle.set_device(self.args.device) + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') if self.parallel: self.init_parallel() @@ -122,7 +139,7 @@ class Trainer(): """A flag indicating whether the experiment should run with multiprocessing. """ - return self.args.device == "gpu" and self.args.nprocs > 1 + return self.args.nprocs > 0 def init_parallel(self): """Init environment for multiprocess training. @@ -162,67 +179,108 @@ class Trainer(): checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) if infos: - # restore from ckpt + # just restore ckpt + # lr will resotre from optimizer ckpt self.iteration = infos["step"] self.epoch = infos["epoch"] scratch = False + logger.info( + f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") else: self.iteration = 0 self.epoch = 0 scratch = True - + logger.info("Init from scratch!") return scratch - def new_epoch(self): - """Reset the train loader seed and increment `epoch`. - """ - self.epoch += 1 - if self.parallel and hasattr(self.train_loader, "batch_sampler"): + def maybe_batch_sampler_step(self): + """ batch_sampler seed by epoch """ + if hasattr(self.train_loader, "batch_sampler"): batch_sampler = self.train_loader.batch_sampler if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): batch_sampler.set_epoch(self.epoch) - def train(self): - """The training process control by epoch.""" + def before_train(self): from_scratch = self.resume_or_scratch() if from_scratch: - # save init model, i.e. 0 epoch + # scratch: save init model, i.e. 0 epoch self.save(tag='init', infos=None) - self.lr_scheduler.step(self.epoch) - if self.parallel and hasattr(self.train_loader, "batch_sampler"): - self.train_loader.batch_sampler.set_epoch(self.epoch) + else: + # resume: train next_epoch and next_iteration + self.epoch += 1 + self.iteration += 1 + logger.info( + f"Resume train: epoch {self.epoch }, step {self.iteration}!") + + self.maybe_batch_sampler_step() + + def new_epoch(self): + """Reset the train loader seed and increment `epoch`. + """ + # `iteration` increased by train step + self.epoch += 1 + self.maybe_batch_sampler_step() + + def after_train_batch(self): + if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step: + profiler.add_profiler_step(self.args.profiler_options) + logger.info( + f"Reach benchmark-max-step: {self.args.benchmark_max_step}") + sys.exit( + f"Reach benchmark-max-step: {self.args.benchmark_max_step}") + + def train(self): + """The training process control by epoch.""" + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train:" + observation = OrderedDict() + with ObsScope(observation): + report("Rank", dist.get_rank()) + report("epoch", self.epoch) + report('step', self.iteration) + report("lr", self.lr_scheduler()) + self.train_batch(batch_index, batch, msg) + self.after_train_batch() + report('iter', batch_index + 1) + report('total', len(self.train_loader)) + report('reader_cost', dataload_time) + observation['batch_cost'] = observation[ + 'reader_cost'] + observation['step_cost'] + observation['samples'] = observation['batch_size'] + observation['ips[sent./sec]'] = observation[ + 'batch_size'] / observation['batch_cost'] + for k, v in observation.items(): + msg += f" {k}: " + msg += f"{v:>.8f}" if isinstance(v, + float) else f"{v}" + msg += "," + logger.info(msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) @@ -231,6 +289,7 @@ class Trainer(): 'epoch', {'cv_loss': cv_loss, 'lr': self.lr_scheduler()}, self.epoch) + # after epoch self.save(tag=self.epoch, infos={'val_loss': cv_loss}) # step lr every epoch self.lr_scheduler.step() @@ -240,14 +299,13 @@ class Trainer(): """The routine of the experiment after setup. This method is intended to be used by the user. """ - try: - self.train() - except KeyboardInterrupt: - self.save() - exit(-1) - finally: - self.destory() - logger.info("Training Done.") + with Timer("Training Done: {}"): + try: + self.train() + except KeyboardInterrupt: + exit(-1) + finally: + self.destory() def setup_output_dir(self): """Create a directory used for output. diff --git a/deepspeech/training/updaters/standard_updater.py b/deepspeech/training/updaters/standard_updater.py index fc758e93e..10c99e7fc 100644 --- a/deepspeech/training/updaters/standard_updater.py +++ b/deepspeech/training/updaters/standard_updater.py @@ -14,12 +14,12 @@ from typing import Dict from typing import Optional -from paddle import Tensor +import paddle from paddle.io import DataLoader from paddle.io import DistributedBatchSampler from paddle.nn import Layer from paddle.optimizer import Optimizer -from timer import timer +from paddle.optimizer.lr import LRScheduler from deepspeech.training.reporter import report from deepspeech.training.updaters.updater import UpdaterBase @@ -39,8 +39,10 @@ class StandardUpdater(UpdaterBase): def __init__(self, model: Layer, optimizer: Optimizer, + scheduler: LRScheduler, dataloader: DataLoader, init_state: Optional[UpdaterState]=None): + super().__init__(init_state) # it is designed to hold multiple models models = {"main": model} self.models: Dict[str, Layer] = models @@ -51,15 +53,14 @@ class StandardUpdater(UpdaterBase): self.optimizer = optimizer self.optimizers: Dict[str, Optimizer] = optimizers + # it is designed to hold multiple scheduler + schedulers = {"main": scheduler} + self.scheduler = scheduler + self.schedulers: Dict[str, LRScheduler] = schedulers + # dataloaders self.dataloader = dataloader - # init state - if init_state is None: - self.state = UpdaterState() - else: - self.state = init_state - self.train_iterator = iter(dataloader) def update(self): @@ -103,8 +104,10 @@ class StandardUpdater(UpdaterBase): model.train() # training for a step is implemented here - batch = self.read_batch() - self.update_core(batch) + with Timier("data time cost:{}"): + batch = self.read_batch() + with Timier("step time cost:{}"): + self.update_core(batch) self.state.iteration += 1 if self.updates_per_epoch is not None: @@ -115,13 +118,14 @@ class StandardUpdater(UpdaterBase): """A simple case for a training step. Basic assumptions are: Single model; Single optimizer; + Single scheduler, and update learning rate each step; A batch from the dataloader is just the input of the model; The model return a single loss, or a dict containing serval losses. Parameters updates at every batch, no gradient accumulation. """ loss = self.model(*batch) - if isinstance(loss, Tensor): + if isinstance(loss, paddle.Tensor): loss_dict = {"main": loss} else: # Dict[str, Tensor] @@ -135,14 +139,15 @@ class StandardUpdater(UpdaterBase): for name, loss_item in loss_dict.items(): report(name, float(loss_item)) - self.optimizer.clear_gradient() + self.optimizer.clear_grad() loss_dict["main"].backward() - self.optimizer.update() + self.optimizer.step() + self.scheduler.step() @property def updates_per_epoch(self): - """Number of updater per epoch, determined by the length of the - dataloader.""" + """Number of steps per epoch, + determined by the length of the dataloader.""" length_of_dataloader = None try: length_of_dataloader = len(self.dataloader) @@ -163,18 +168,16 @@ class StandardUpdater(UpdaterBase): def read_batch(self): """Read a batch from the data loader, auto renew when data is exhausted.""" - with timer() as t: - try: - batch = next(self.train_iterator) - except StopIteration: - self.new_epoch() - batch = next(self.train_iterator) - logger.debug( - f"Read a batch takes {t.elapse}s.") # replace it with logger + try: + batch = next(self.train_iterator) + except StopIteration: + self.new_epoch() + batch = next(self.train_iterator) return batch def state_dict(self): - """State dict of a Updater, model, optimizer and updater state are included.""" + """State dict of a Updater, model, optimizers/schedulers + and updater state are included.""" state_dict = super().state_dict() for name, model in self.models.items(): state_dict[f"{name}_params"] = model.state_dict() @@ -184,7 +187,7 @@ class StandardUpdater(UpdaterBase): def set_state_dict(self, state_dict): """Set state dict for a Updater. Parameters of models, states for - optimizers and UpdaterState are restored.""" + optimizers/schedulers and UpdaterState are restored.""" for name, model in self.models.items(): model.set_state_dict(state_dict[f"{name}_params"]) for name, optim in self.optimizers.items(): diff --git a/deepspeech/training/updaters/trainer.py b/deepspeech/training/updaters/trainer.py index 954ce2604..077694659 100644 --- a/deepspeech/training/updaters/trainer.py +++ b/deepspeech/training/updaters/trainer.py @@ -24,7 +24,7 @@ import tqdm from deepspeech.training.extensions.extension import Extension from deepspeech.training.extensions.extension import PRIORITY_READER -from deepspeech.training.reporter import scope +from deepspeech.training.reporter import ObsScope from deepspeech.training.triggers import get_trigger from deepspeech.training.triggers.limit_trigger import LimitTrigger from deepspeech.training.updaters.updater import UpdaterBase @@ -140,11 +140,11 @@ class Trainer(): try: while not stop_trigger(self): self.observation = {} - # set observation as the report target - # you can use report freely in Updater.update() + # set observation as the `report` target + # you can use `report` freely in Updater.update() # updating parameters and state - with scope(self.observation): + with ObsScope(self.observation): update() p.update() diff --git a/deepspeech/training/updaters/updater.py b/deepspeech/training/updaters/updater.py index 66fdc2bbc..e5dd65563 100644 --- a/deepspeech/training/updaters/updater.py +++ b/deepspeech/training/updaters/updater.py @@ -52,6 +52,7 @@ class UpdaterBase(): """ def __init__(self, init_state=None): + # init state if init_state is None: self.state = UpdaterState() else: diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index a59f8be79..8e31edfae 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -114,13 +114,13 @@ class Checkpoint(): params_path = checkpoint_path + ".pdparams" model_dict = paddle.load(params_path) model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) + logger.info("Rank {}: Restore model from {}".format(rank, params_path)) optimizer_path = checkpoint_path + ".pdopt" if optimizer and os.path.isfile(optimizer_path): optimizer_dict = paddle.load(optimizer_path) optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( + logger.info("Rank {}: Restore optimizer state from {}".format( rank, optimizer_path)) info_path = re.sub('.pdparams$', '.json', params_path) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 09543d48d..fc43a71f0 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -84,19 +84,19 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) + (ctc_probs.shape[0], len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero - # TODO(Hui Zhang): zeros not support paddle.int16 + + # self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 + (ctc_probs.shape[0], len(y_insert_blank)), dtype=paddle.int32) - 1 ) # state path, Tuple((T, 2L+1)) # init start state - # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 - log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb - log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb + log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): # T + for t in range(1, ctc_probs.shape[0]): # T for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: @@ -110,13 +110,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( - y_insert_blank[s])] + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ + y_insert_blank[s]] state_path[t, s] = prev_state[paddle.argmax(candidates)] - - # TODO(Hui Zhang): zeros not support paddle.int16 - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32) + # self.__setitem_varbase__(item, value) When assign a value to a paddle.Tensor, the data type of the paddle.Tensor not support int16 + state_seq = -1 * paddle.ones((ctc_probs.shape[0], 1), dtype=paddle.int32) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb @@ -124,11 +122,11 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, ]) prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] state_seq[-1] = prev_state[paddle.argmax(candidates)] - for t in range(ctc_probs.size(0) - 2, -1, -1): + for t in range(ctc_probs.shape[0] - 2, -1, -1): state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] output_alignment = [] - for t in range(0, ctc_probs.size(0)): + for t in range(0, ctc_probs.shape[0]): output_alignment.append(y_insert_blank[state_seq[t, 0]]) return output_alignment diff --git a/deepspeech/utils/log.py b/deepspeech/utils/log.py index 3fd7d2480..7e8de600a 100644 --- a/deepspeech/utils/log.py +++ b/deepspeech/utils/log.py @@ -12,19 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import getpass -import logging import os import socket import sys +from loguru import logger from paddle import inference -FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' -DATE_FMT_STR = '%Y/%m/%d %H:%M:%S' - -logging.basicConfig( - level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR) - def find_log_dir(log_dir=None): """Returns the most suitable directory to put log files into. @@ -98,59 +92,28 @@ def find_log_dir_and_names(program_name=None, log_dir=None): class Log(): - - log_name = None - - def __init__(self, logger=None): - self.logger = logging.getLogger(logger) - self.logger.setLevel(logging.DEBUG) - - file_dir = os.getcwd() + '/log' - if not os.path.exists(file_dir): - os.mkdir(file_dir) - self.log_dir = file_dir - - actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( - program_name=None, log_dir=self.log_dir) - - basename = '%s.DEBUG.%d' % (file_prefix, os.getpid()) - filename = os.path.join(actual_log_dir, basename) - if Log.log_name is None: - Log.log_name = filename - - # Create a symlink to the log file with a canonical name. - symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG') - try: - if os.path.islink(symlink): - os.unlink(symlink) - os.symlink(os.path.basename(Log.log_name), symlink) - except EnvironmentError: - # If it fails, we're sad but it's no error. Commonly, this - # fails because the symlink was created by another user and so - # we can't modify it - pass - - if not self.logger.hasHandlers(): - formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR) - fh = logging.FileHandler(Log.log_name) - fh.setLevel(logging.DEBUG) - fh.setFormatter(formatter) - self.logger.addHandler(fh) - - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - self.logger.addHandler(ch) - - # stop propagate for propagating may print - # log multiple times - self.logger.propagate = False + """Default Logger for all.""" + logger.remove() + logger.add( + sys.stdout, + level='INFO', + enqueue=True, + filter=lambda record: record['level'].no >= 20) + _, file_prefix, _ = find_log_dir_and_names() + sink_prefix = os.path.join("exp/log", file_prefix) + sink_path = sink_prefix[:-3] + "{time}.log" + logger.add(sink_path, level='DEBUG', enqueue=True, rotation="500 MB") + + def __init__(self, name=None): + pass def getlog(self): - return self.logger + return logger class Autolog: + """Just used by fullchain project""" + def __init__(self, batch_size, model_name="DeepSpeech", diff --git a/deepspeech/utils/profiler.py b/deepspeech/utils/profiler.py new file mode 100644 index 000000000..5733f8ed5 --- /dev/null +++ b/deepspeech/utils/profiler.py @@ -0,0 +1,119 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +import paddle + +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions(object): + ''' + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". + For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + ProfilerOptions supports following key-value pair: + batch_range - a integer list, e.g. [100, 110]. + state - a string, the optional values are 'CPU', 'GPU' or 'All'. + sorted_key - a string, the optional values are 'calls', 'total', + 'max', 'min' or 'ave. + tracer_option - a string, the optional values are 'Default', 'OpDetail', + 'AllOpDetail'. + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. + ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + if not options_str: + return + + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + profiler_options - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. + ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + logger.info(f"Profiler: {options_str}") + logger.info(f"Profiler: {_profiler_options._options}") + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler(_profiler_options['state'], + _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 9bff6b0f3..61798816b 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -83,7 +83,7 @@ def pad_sequence(sequences: List[paddle.Tensor], # (TODO Hui Zhang): slice not supprot `end==start` # trailing_dims = max_size[1:] trailing_dims = max_size[1:] if max_size.ndim >= 2 else () - max_len = max([s.size(0) for s in sequences]) + max_len = max([s.shape[0] for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims else: @@ -91,12 +91,22 @@ def pad_sequence(sequences: List[paddle.Tensor], out_tensor = sequences[0].new_full(out_dims, padding_value) for i, tensor in enumerate(sequences): - length = tensor.size(0) + length = tensor.shape[0] # use index notation to prevent duplicate references to the tensor if batch_first: - out_tensor[i, :length, ...] = tensor + # TODO (Hui Zhang): set_value op not supprot `end==start` + # out_tensor[i, :length, ...] = tensor + if length != 0: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[i, length, ...] = tensor else: - out_tensor[:length, i, ...] = tensor + # TODO (Hui Zhang): set_value op not supprot `end==start` + # out_tensor[:length, i, ...] = tensor + if length != 0: + out_tensor[:length, i, ...] = tensor + else: + out_tensor[length, i, ...] = tensor return out_tensor @@ -139,7 +149,7 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) - B = ys_pad.size(0) + B = ys_pad.shape[0] _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos ys_in = paddle.cat([_sos, ys_pad], dim=1) @@ -165,16 +175,10 @@ def th_accuracy(pad_outputs: paddle.Tensor, Returns: float: Accuracy value (0.0 - 1.0). """ - pad_pred = pad_outputs.view( - pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2) + pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], + pad_outputs.shape[1]).argmax(2) mask = pad_targets != ignore_label - #TODO(Hui Zhang): sum not support bool type - # numerator = paddle.sum( - # pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - numerator = ( + numerator = paddle.sum( pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) - numerator = paddle.sum(numerator.type_as(pad_targets)) - #TODO(Hui Zhang): sum not support bool type - # denominator = paddle.sum(mask) - denominator = paddle.sum(mask.type_as(pad_targets)) + denominator = paddle.sum(mask) return float(numerator) / float(denominator) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index e18fc1f77..6f84c41be 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -16,15 +16,27 @@ import distutils.util import math import os import random +from contextlib import contextmanager from typing import List import numpy as np import paddle -__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"] +__all__ = [ + "UpdateConfig", "seed_all", 'print_arguments', 'add_arguments', "log_add" +] + + +@contextmanager +def UpdateConfig(config): + """Update yacs config""" + config.defrost() + yield + config.freeze() def seed_all(seed: int=210329): + """freeze random generator seed.""" np.random.seed(seed) random.seed(seed) paddle.seed(seed) diff --git a/doc/images/multi_gpu_speedup.png b/doc/images/multi_gpu_speedup.png deleted file mode 100755 index 286de5151..000000000 Binary files a/doc/images/multi_gpu_speedup.png and /dev/null differ diff --git a/doc/images/tuning_error_surface.png b/doc/images/tuning_error_surface.png deleted file mode 100644 index 2204cee2f..000000000 Binary files a/doc/images/tuning_error_surface.png and /dev/null differ diff --git a/doc/src/benchmark.md b/doc/src/benchmark.md deleted file mode 100644 index 9c1c86fd7..000000000 --- a/doc/src/benchmark.md +++ /dev/null @@ -1,16 +0,0 @@ -# Benchmarks - -## Acceleration with Multi-GPUs - -We compare the training time with 1, 2, 4, 8 Tesla V100 GPUs (with a subset of LibriSpeech samples whose audio durations are between 6.0 and 7.0 seconds). And it shows that a **near-linear** acceleration with multiple GPUs has been achieved. In the following figure, the time (in seconds) cost for training is printed on the blue bars. - - - -| # of GPU | Acceleration Rate | -| -------- | --------------: | -| 1 | 1.00 X | -| 2 | 1.98 X | -| 4 | 3.73 X | -| 8 | 6.95 X | - -`utils/profile.sh` provides such a demo profiling tool, you can change it as need. diff --git a/doc/src/faq.md b/doc/src/faq.md deleted file mode 100644 index e29428176..000000000 --- a/doc/src/faq.md +++ /dev/null @@ -1,37 +0,0 @@ -# FAQ - -1. 音频变速快慢到达什么晨读会影响识别率? - - 变速会提升识别效果,一般用0.9, 1.0, 1.1 的变速。 - -2. 音量大小到什么程度会影响识别率? - - 一般训练会固定音量到一个范围内,波动过大会影响训练,估计在10dB ~ 20dB吧。 - -3. 语音模型训练数据的最小时长要求时多少? - - Aishell-1大约178h的数据,数据越多越好。 - -4. 那些噪声或背景生会影响识别率? - - 主要是人生干扰和低信噪比会影响识别率。 - -5. 单条语音数据的长度限制是多少? - - 一般训练的语音长度会限制在1s~6s之间,和训练配置有关。 - -6. 背景声在识别前是否需要分离出来,或做降噪处理? - - 需要分离的,需要结合具体场景考虑。 - -7. 模型是否带有VAD人生激活识别能力? - - VAD是单独的模型或模块,模型不包含此能力。 - -8. 是否支持长语音识别? - - 一般过VAD后识别。 - -9. Mandarin LM Large语言模型需要的硬件配置时怎样的? - - 内存能放得下LM即可。 diff --git a/doc/src/reference.md b/doc/src/reference.md deleted file mode 100644 index 69ff6ab88..000000000 --- a/doc/src/reference.md +++ /dev/null @@ -1,3 +0,0 @@ -# Reference - -* [wenet](https://github.com/mobvoi/wenet) diff --git a/doc/src/released_model.md b/doc/src/released_model.md deleted file mode 100644 index 0919bba58..000000000 --- a/doc/src/released_model.md +++ /dev/null @@ -1,9 +0,0 @@ -# Released Models - -## Language Model Released - -Language Model | Training Data | Token-based | Size | Descriptions -:-------------:| :------------:| :-----: | -----: | :----------------- -[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1;
About 1.85 billion n-grams;
'trie' binary with '-a 22 -q 8 -b 8' -[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4;
About 0.13 billion n-grams;
'probing' binary with default settings -[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning;
About 3.7 billion n-grams;
'probing' binary with default settings diff --git a/doc/src/server.md b/doc/src/server.md deleted file mode 100644 index 4918d5ebe..000000000 --- a/doc/src/server.md +++ /dev/null @@ -1,34 +0,0 @@ - -# Trying Live Demo with Your Own Voice - -Until now, an ASR model is trained and tested qualitatively (`infer`) and quantitatively (`test`) with existing audio files. But it is not yet tested with your own speech. We build up a real-time demo ASR engine with the trained model, enabling you to test and play around with the demo, with your own voice. - -First, change your directory to `examples/aishell` and `source path.sh`. - -To start the demo's server, please run this in one console: - -```bash -CUDA_VISIBLE_DEVICES=0 bash local/server.sh -``` - -For the machine (might not be the same machine) to run the demo's client, please do the following installation before moving on. - -For example, on MAC OS X: - -```bash -brew install portaudio -pip install pyaudio -pip install keyboard -``` - -Then to start the client, please run this in another console: - -```bash -CUDA_VISIBLE_DEVICES=0 bash local/client.sh -``` - -Now, in the client console, press the `whitespace` key, hold, and start speaking. Until finishing your utterance, release the key to let the speech-to-text results shown in the console. To quit the client, just press `ESC` key. - -Notice that `deepspeech/exps/deepspeech2/deploy/client.py` must be run on a machine with a microphone device, while `deepspeech/exps/deepspeech2/deploy/server.py` could be run on one without any audio recording hardware, e.g. any remote server machine. Just be careful to set the `host_ip` and `host_port` argument with the actual accessible IP address and port, if the server and client are running with two separate machines. Nothing should be done if they are running on one single machine. - -Please also refer to `examples/aishell/local/server.sh`, which will first download a pre-trained Chinese model (trained with AISHELL1) and then start the demo server with the model. With running `examples/aishell/local/client.sh`, you can speak Chinese to test it. If you would like to try some other models, just update `--checkpoint_path` argument in the script.   diff --git a/docs/images/ds2offlineModel.png b/docs/images/ds2offlineModel.png new file mode 100644 index 000000000..0d8722ab0 Binary files /dev/null and b/docs/images/ds2offlineModel.png differ diff --git a/docs/images/ds2onlineModel.png b/docs/images/ds2onlineModel.png new file mode 100644 index 000000000..97a0e5619 Binary files /dev/null and b/docs/images/ds2onlineModel.png differ diff --git a/doc/src/augmentation.md b/docs/src/augmentation.md similarity index 100% rename from doc/src/augmentation.md rename to docs/src/augmentation.md diff --git a/doc/src/data_preparation.md b/docs/src/data_preparation.md similarity index 100% rename from doc/src/data_preparation.md rename to docs/src/data_preparation.md diff --git a/docs/src/deepspeech_architecture.md b/docs/src/deepspeech_architecture.md new file mode 100644 index 000000000..b93441222 --- /dev/null +++ b/docs/src/deepspeech_architecture.md @@ -0,0 +1,190 @@ +# Deepspeech2 +## Streaming + +The implemented arcitecure of Deepspeech2 online model is based on [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes. +The model is mainly composed of 2D convolution subsampling layer and stacked single direction rnn layers. + +To illustrate the model implementation clearly, 3 parts are described in detail. +- Data Preparation +- Encoder +- Decoder + +In addition, the training process and the testing process are also introduced. + +The arcitecture of the model is shown in Fig.1. + +

+ +
Fig.1 The Arcitecture of deepspeech2 online model +

+ +### Data Preparation +#### Vocabulary +For English data, the vocabulary dictionary is composed of 26 English characters with " ' ", space, \ and \. The \ represents the blank label in CTC, the \ represents the unknown character and the \ represents the start and the end characters. For mandarin, the vocabulary dictionary is composed of chinese characters statisticed from the training set and three additional characters are added. The added characters are \, \ and \. For both English and mandarin data, we set the default indexs that \=0, \=1 and \= last index. +``` + # The code to build vocabulary + cd examples/aishell/s0 + python3 ../../../utils/build_vocab.py \ + --unit_type="char" \ + --count_threshold=0 \ + --vocab_path="data/vocab.txt" \ + --manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw" + +# vocabulary for aishell dataset (Mandarin) +vi examples/aishell/s0/data/vocab.txt + +# vocabulary for librispeech dataset (English) +vi examples/librispeech/s0/data/vocab.txt +``` + +#### CMVN +For CMVN, a subset or the full of traininig set is chosed and be used to compute the feature mean and std. +``` + # The code to compute the feature mean and std +cd examples/aishell/s0 +python3 ../../../utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --specgram_type="linear" \ + --delta_delta=false \ + --stride_ms=10.0 \ + --window_ms=20.0 \ + --sample_rate=16000 \ + --use_dB_normalization=True \ + --num_samples=2000 \ + --num_workers=10 \ + --output_path="data/mean_std.json" + +``` + +#### Feature Extraction + For feature extraction, three methods are implemented, which are linear (FFT without using filter bank), fbank and mfcc. + Currently, the released deepspeech2 online model use the linear feature extraction method. + ``` + The code for feature extraction + vi deepspeech/frontend/featurizer/audio_featurizer.py + ``` + +### Encoder +The encoder is composed of two 2D convolution subsampling layers and a number of stacked single direction rnn layers. The 2D convolution subsampling layers extract feature representation from the raw audio feature and reduce the length of audio feature at the same time. After passing through the convolution subsampling layers, then the feature representation are input into the stacked rnn layers. For the stacked rnn layers, LSTM cell and GRU cell are provided to use. Adding one fully connected (fc) layer after the stacked rnn layers is optional. If the number of stacked rnn layers is less than 5, adding one fc layer after stacked rnn layers is recommand. + +The code of Encoder is in: +``` +vi deepspeech/models/ds2_online/deepspeech2.py +``` + +### Decoder +To got the character possibilities of each frame, the feature representation of each frame output from the encoder are input into a projection layer which is implemented as a dense layer to do feature projection. The output dim of the projection layer is same with the vocabulary size. After projection layer, the softmax function is used to transform the frame-level feature representation be the possibilities of characters. While making model inference, the character possibilities of each frame are input into the CTC decoder to get the final speech recognition results. + +The code of the decoder is in: +``` +# The code of constructing the decoder in model +vi deepspeech/models/ds2_online/deepspeech2.py +# The code of CTC Decoder +vi deepspeech/modules/ctc.py +``` + +## Training Process +Using the command below, you can train the deepspeech2 online model. +``` + cd examples/aishell/s0 + bash run.sh --stage 0 --stop_stage 2 --model_type online --conf_path conf/deepspeech2_online.yaml +``` +The detail commands are: +``` +# The code for training in run.sh +set -e +source path.sh + +gpus=2,3,5,7 +stage=0 +stop_stage=5 +conf_path=conf/deepspeech2_online.yaml # conf/deepspeech2.yaml | conf/deepspeech2_online.yaml +avg_num=1 +model_type=online # online | offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +avg_ckpt=avg_${avg_num} +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `exp` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${model_type} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # avg n best model + avg.sh exp/${ckpt}/checkpoints ${avg_num} +fi +``` + +By using the command above, the training process can be started. There are 5 stages in "run.sh", and the first 3 stages are used for training process. The stage 0 is used for data preparation, in which the dataset will be downloaded, and the manifest files of the datasets, vocabulary dictionary and CMVN file will be generated in "./data/". The stage 1 is used for training the model, the log files and model checkpoint is saved in "exp/deepspeech2_online/". The stage 2 is used to generated final model for predicting by averaging the top-k model parameters based on validation loss. + +## Testing Process +Using the command below, you can test the deepspeech2 online model. + ``` + bash run.sh --stage 3 --stop_stage 5 --model_type online --conf_path conf/deepspeech2_online.yaml +``` +The detail commands are: +``` +conf_path=conf/deepspeech2_online.yaml +avg_num=1 +model_type=online +avg_ckpt=avg_${avg_num} + + if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=2 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${model_type}|| exit -1 +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # export ckpt avg_n + CUDA_VISIBLE_DEVICES=5 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type} +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # test export ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test_export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt}.jit ${model_type}|| exit -1 +fi + ``` +After the training process, we use stage 3,4,5 for testing process. The stage 3 is for testing the model generated in the stage 2 and provided the CER index of the test set. The stage 4 is for transforming the model from dynamic graph to static graph by using "paddle.jit" library. The stage 5 is for testing the model in static graph. + + +## Non-Streaming +The deepspeech2 offline model is similarity to the deepspeech2 online model. The main difference between them is the offline model use the stacked bi-directional rnn layers while the online model use the single direction rnn layers and the fc layer is not used. For the stacked bi-directional rnn layers in the offline model, the rnn cell and gru cell are provided to use. + +The arcitecture of the model is shown in Fig.2. +

+ +
Fig.2 The Arcitecture of deepspeech2 offline model +

+ + + +For data preparation and decoder, the deepspeech2 offline model is same with the deepspeech2 online model. + +The code of encoder and decoder for deepspeech2 offline model is in: +``` +vi deepspeech/models/ds2/deepspeech2.py +``` + +The training process and testing process of deepspeech2 offline model is very similary to deepspeech2 online model. +Only some changes should be noticed. + +For training and testing, the "model_type" and the "conf_path" must be set. + ``` +# Training offline +cd examples/aishell/s0 +bash run.sh --stage 0 --stop_stage 2 --model_type offline --conf_path conf/deepspeech2.yaml +``` +``` +# Testing offline +cd examples/aishell/s0 +bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml +``` diff --git a/doc/src/feature_list.md b/docs/src/feature_list.md similarity index 79% rename from doc/src/feature_list.md rename to docs/src/feature_list.md index b675d8100..4639ddd6f 100644 --- a/doc/src/feature_list.md +++ b/docs/src/feature_list.md @@ -1,13 +1,20 @@ # Features +### Dataset +* Aishell +* Librispeech +* THCHS30 +* TIMIT + ### Speech Recognition -* Offline +* Non-Streaming * [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf) * [Transformer](https://arxiv.org/abs/1706.03762) * [Conformer](https://arxiv.org/abs/2005.08100) -* Online +* Streaming + * [Baidu's DeepSpeech2](http://proceedings.mlr.press/v48/amodei16.pdf) * [U2](https://arxiv.org/pdf/2012.05481.pdf) ### Language Model @@ -22,6 +29,15 @@ * beam search * attention rescore +### Deployment + +* Paddle Inference + +### Aligment + +* MFA +* CTC Aligment + ### Speech Frontend * Audio diff --git a/doc/src/getting_started.md b/docs/src/getting_started.md similarity index 100% rename from doc/src/getting_started.md rename to docs/src/getting_started.md diff --git a/doc/src/install.md b/docs/src/install.md similarity index 95% rename from doc/src/install.md rename to docs/src/install.md index 01049a2fc..8cecba125 100644 --- a/doc/src/install.md +++ b/docs/src/install.md @@ -4,15 +4,16 @@ To avoid the trouble of environment setup, [running in Docker container](#runnin ## Prerequisites - Python >= 3.7 -- PaddlePaddle 2.0.0 or later (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html)) +- PaddlePaddle latest version (please refer to the [Installation Guide](https://www.paddlepaddle.org.cn/documentation/docs/en/beginners_guide/index_en.html)) -## Setup +## Setup (Important) - Make sure these libraries or tools installed: `pkg-config`, `flac`, `ogg`, `vorbis`, `boost`, `sox, and `swig`, e.g. installing them via `apt-get`: ```bash sudo apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev ``` +The version of `swig` should >= 3.0 or, installing them via `yum`: diff --git a/doc/src/ngram_lm.md b/docs/src/ngram_lm.md similarity index 64% rename from doc/src/ngram_lm.md rename to docs/src/ngram_lm.md index 119a3b21c..7872df22d 100644 --- a/doc/src/ngram_lm.md +++ b/docs/src/ngram_lm.md @@ -35,52 +35,3 @@ Different from the English language model, Mandarin language model is character- * A whitespace character between two tokens is inserted. Please notice that the released language models only contain Chinese simplified characters. After preprocessing done we can begin to train the language model. The key training arguments for small LM is '-o 5 --prune 0 1 2 4 4' and '-o 5' for large LM. Please refer above section for the meaning of each argument. We also convert the arpa file to binary file using default settings. - - - -## [KenLM](http://kheafield.com/code/kenlm/) - -统计语言模型工具有比较多的选择,目前使用比较好的有srilm及kenlm,其中kenlm比srilm晚出来,训练速度也更快,而且支持单机大数据的训练。现在介绍一下kenlm的使用方法。 - -1. 工具包的下载地址:http://kheafield.com/code/kenlm.tar.gz - -2. 使用。该工具在linux环境下使用方便。 先确保linux环境已经按照1.36.0的Boost和zlib - - ``` - boost: - yum install boost - yum install boost-devel - - zlib: - yum install zlib - yum install zlib-devel - ``` - - 然后gcc版本需要是4.8.2及以上。 - - ``` - wget -O - https://kheafield.com/code/kenlm.tar.gz |tar xz - mkdir kenlm/build - cd kenlm/build - cmake .. - make -j2 - ``` - -3. 训练。使用如下命令进行训练: - - ``` - build/bin/lmplz -o 3 --verbose_header --text people2014corpus_words.txt --arpa result/people2014corpus_words.arps - ``` - - 其中, - 1)people2014corpus_words.txt文件必须是分词以后的文件。 - - 训练语料<人民日报2014版熟语料>,包括: 1)标准人工切词及词性数据people2014.tar.gz, 2)未切词文本数据people2014_words.txt, 3)kenlm训练字粒度语言模型文件及其二进制文件people2014corpus_chars.arps/klm, 4)kenlm词粒度语言模型文件及其二进制文件people2014corpus_words.arps/klm。 - - 2)-o后面的5表示的是5-gram,一般取到3即可,但可以结合自己实际情况判断。 - -4. 压缩。压缩模型为二进制,方便模型快速加载: - - ``` - build/bin/build_binary ./result/people2014corpus_words.arps ./result/people2014corpus_words.klm - ``` diff --git a/docs/src/reference.md b/docs/src/reference.md new file mode 100644 index 000000000..d3676fff2 --- /dev/null +++ b/docs/src/reference.md @@ -0,0 +1,8 @@ +# Reference + +We refer these repos to build `model` and `engine`: + +* [delta](https://github.com/Delta-ML/delta.git) +* [espnet](https://github.com/espnet/espnet.git) +* [kaldi](https://github.com/kaldi-asr/kaldi.git) +* [wenet](https://github.com/mobvoi/wenet) diff --git a/docs/src/released_model.md b/docs/src/released_model.md new file mode 100644 index 000000000..61fd1560e --- /dev/null +++ b/docs/src/released_model.md @@ -0,0 +1,28 @@ +# Released Models + +## Acoustic Model Released in paddle 2.X +Acoustic Model | Training Data | Token-based | Size | Descriptions | CER or WER | Hours of speech +:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :--------- +[Ds2 Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds_online.5rnn.debug.tar.gz) | Aishell Dataset | Char-based | 345 MB | 2 Conv + 5 LSTM layers with only forward direction | 0.0824 | 151 h +[Ds2 Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s0/aishell.s0.ds2.offline.cer6p65.release.tar.gz)| Aishell Dataset | Char-based | 306 MB | 2 Conv + 3 bidirectional GRU layers| 0.065 | 151 h +[Conformer Online Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.chunk.release.tar.gz) | Aishell Dataset | Char-based | 283 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention + CTC | 0.0594 | 151 h +[Conformer Offline Aishell Model](https://deepspeech.bj.bcebos.com/release2.1/aishell/s1/aishell.release.tar.gz) | Aishell Dataset | Char-based | 284 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0547 | 151 h +[Conformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/conformer.release.tar.gz) | Librispeech Dataset | Word-based | 287 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0325 | 960 h +[Transformer Librispeech Model](https://deepspeech.bj.bcebos.com/release2.1/librispeech/s1/transformer.release.tar.gz) | Librispeech Dataset | Word-based | 195 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention | 0.0544 | 960 h + +## Acoustic Model Transformed from paddle 1.8 +Acoustic Model | Training Data | Token-based | Size | Descriptions | CER or WER | Hours of speech +:-------------:| :------------:| :-----: | -----: | :----------------- | :---------- | :--------- +[Ds2 Offline Aishell model](https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz)|Aishell Dataset| Char-based| 234 MB| 2 Conv + 3 bidirectional GRU layers| 0.0804 | 151 h| +[Ds2 Offline Librispeech model](https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz)|Librispeech Dataset| Word-based| 307 MB| 2 Conv + 3 bidirectional sharing weight RNN layers | 0.0685| 960 h| +[Ds2 Offline Baidu en8k model](https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz)|Baidu Internal English Dataset| Word-based| 273 MB| 2 Conv + 3 bidirectional GRU layers | 0.0541 | 8628 h| + + + +## Language Model Released + +Language Model | Training Data | Token-based | Size | Descriptions +:-------------:| :------------:| :-----: | -----: | :----------------- +[English LM](https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm) | [CommonCrawl(en.00)](http://web-language-models.s3-website-us-east-1.amazonaws.com/ngrams/en/deduped/en.00.deduped.xz) | Word-based | 8.3 GB | Pruned with 0 1 1 1 1;
About 1.85 billion n-grams;
'trie' binary with '-a 22 -q 8 -b 8' +[Mandarin LM Small](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) | Baidu Internal Corpus | Char-based | 2.8 GB | Pruned with 0 1 2 4 4;
About 0.13 billion n-grams;
'probing' binary with default settings +[Mandarin LM Large](https://deepspeech.bj.bcebos.com/zh_lm/zhidao_giga.klm) | Baidu Internal Corpus | Char-based | 70.4 GB | No Pruning;
About 3.7 billion n-grams;
'probing' binary with default settings diff --git a/examples/1xt2x/.gitignore b/examples/1xt2x/.gitignore new file mode 100644 index 000000000..a9a5aecf4 --- /dev/null +++ b/examples/1xt2x/.gitignore @@ -0,0 +1 @@ +tmp diff --git a/examples/1xt2x/README.md b/examples/1xt2x/README.md new file mode 100644 index 000000000..1f5fe8e3b --- /dev/null +++ b/examples/1xt2x/README.md @@ -0,0 +1,11 @@ +# 1xt2x + +Convert Deepspeech 1.8 released model to 2.x. + +## Model +* Deepspeech2x + +## Exp +* baidu_en8k +* aishell +* librispeech diff --git a/examples/1xt2x/aishell/.gitignore b/examples/1xt2x/aishell/.gitignore new file mode 100644 index 000000000..7024e0e95 --- /dev/null +++ b/examples/1xt2x/aishell/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/1xt2x/aishell/conf/augmentation.json b/examples/1xt2x/aishell/conf/augmentation.json new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ b/examples/1xt2x/aishell/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/1xt2x/aishell/conf/deepspeech2.yaml b/examples/1xt2x/aishell/conf/deepspeech2.yaml new file mode 100644 index 000000000..6e745e9d1 --- /dev/null +++ b/examples/1xt2x/aishell/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.0 + max_input_len: 27.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 1024 + use_gru: True + share_rnn_weights: False + blank_id: 4333 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: cer + decoding_method: ctc_beam_search + lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm + alpha: 2.6 + beta: 5.0 + beam_size: 300 + cutoff_prob: 0.99 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/1xt2x/aishell/local/data.sh b/examples/1xt2x/aishell/local/data.sh new file mode 100755 index 000000000..1cde0c6ea --- /dev/null +++ b/examples/1xt2x/aishell/local/data.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + +bash local/download_model.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +tar xzvf aishell_model_v1.8_to_v2.x.tar.gz +mv aishell_v1.8.pdparams exp/deepspeech2/checkpoints/ +mv README.md exp/deepspeech2/ +mv mean_std.npz data/ +mv vocab.txt data/ +rm aishell_model_v1.8_to_v2.x.tar.gz -f + + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/aishell/aishell.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/aishell" + + if [ $? -ne 0 ]; then + echo "Prepare Aishell failed. Terminated." + exit 1 + fi + + for dataset in train dev test; do + mv data/manifest.${dataset} data/manifest.${dataset}.raw + done +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --specgram_type="linear" \ + --delta_delta=false \ + --stride_ms=10.0 \ + --window_ms=20.0 \ + --sample_rate=16000 \ + --use_dB_normalization=True \ + --num_samples=2000 \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for dataset in train dev test; do + { + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type "char" \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${dataset}.raw" \ + --output_path="data/manifest.${dataset}" + + if [ $? -ne 0 ]; then + echo "Formt mnaifest failed. Terminated." + exit 1 + fi + } & + done + wait +fi + +echo "Aishell data preparation done." +exit 0 diff --git a/examples/1xt2x/aishell/local/download_lm_ch.sh b/examples/1xt2x/aishell/local/download_lm_ch.sh new file mode 100755 index 000000000..ac27a9076 --- /dev/null +++ b/examples/1xt2x/aishell/local/download_lm_ch.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +. ${MAIN_ROOT}/utils/utility.sh + +DIR=data/lm +mkdir -p ${DIR} + +URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm' +MD5="29e02312deb2e59b3c8686c7966d4fe3" +TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm + + +echo "Download language model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download the language model!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/aishell/local/download_model.sh b/examples/1xt2x/aishell/local/download_model.sh new file mode 100644 index 000000000..2e4873ef6 --- /dev/null +++ b/examples/1xt2x/aishell/local/download_model.sh @@ -0,0 +1,19 @@ +#! /usr/bin/env bash + +. ${MAIN_ROOT}/utils/utility.sh + +URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz' +MD5=4ade113c69ea291b8ce5ec6a03296659 +TARGET=./aishell_model_v1.8_to_v2.x.tar.gz + + +echo "Download Aishell model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download Aishell model!" + exit 1 +fi +tar -zxvf $TARGET + + +exit 0 diff --git a/examples/1xt2x/aishell/local/test.sh b/examples/1xt2x/aishell/local/test.sh new file mode 100755 index 000000000..2ae0740b3 --- /dev/null +++ b/examples/1xt2x/aishell/local/test.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 +model_type=$3 + +# download language model +bash local/download_lm_ch.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/aishell/path.sh b/examples/1xt2x/aishell/path.sh new file mode 100644 index 000000000..080ab1f79 --- /dev/null +++ b/examples/1xt2x/aishell/path.sh @@ -0,0 +1,16 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` +export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} +export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=deepspeech2 +export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin +echo "BIN_DIR "${BIN_DIR} diff --git a/examples/1xt2x/aishell/run.sh b/examples/1xt2x/aishell/run.sh new file mode 100755 index 000000000..482ab2a09 --- /dev/null +++ b/examples/1xt2x/aishell/run.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/deepspeech2.yaml +avg_num=1 +model_type=offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +v18_ckpt=aishell_v1.8 +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + mkdir -p exp/${ckpt}/checkpoints + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=1 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 +fi + diff --git a/examples/1xt2x/baidu_en8k/.gitignore b/examples/1xt2x/baidu_en8k/.gitignore new file mode 100644 index 000000000..7024e0e95 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/1xt2x/baidu_en8k/conf/augmentation.json b/examples/1xt2x/baidu_en8k/conf/augmentation.json new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/1xt2x/baidu_en8k/conf/deepspeech2.yaml b/examples/1xt2x/baidu_en8k/conf/deepspeech2.yaml new file mode 100644 index 000000000..fbc7466f2 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test-clean + min_input_len: 0.0 + max_input_len: .inf # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 1024 + use_gru: True + share_rnn_weights: False + blank_id: 28 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 1.4 + beta: 0.35 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/1xt2x/baidu_en8k/local/data.sh b/examples/1xt2x/baidu_en8k/local/data.sh new file mode 100755 index 000000000..8f9468b13 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/data.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + + +bash local/download_model.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz +mv baidu_en8k_v1.8.pdparams exp/deepspeech2/checkpoints/ +mv README.md exp/deepspeech2/ +mv mean_std.npz data/ +mv vocab.txt data/ +rm baidu_en8k_v1.8_to_v2.x.tar.gz -f + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/librispeech/librispeech.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/librispeech" \ + --full_download="True" + + if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 + fi + + for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + mv data/manifest.${set} data/manifest.${set}.raw + done + + rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw + for set in train-clean-100 train-clean-360 train-other-500; do + cat data/manifest.${set}.raw >> data/manifest.train.raw + done + + for set in dev-clean dev-other; do + cat data/manifest.${set}.raw >> data/manifest.dev.raw + done + + for set in test-clean test-other; do + cat data/manifest.${set}.raw >> data/manifest.test.raw + done +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --num_samples=2000 \ + --specgram_type="linear" \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=20.0 \ + --use_dB_normalization=True \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for set in train dev test dev-clean dev-other test-clean test-other; do + { + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type ${unit_type} \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${set}.raw" \ + --output_path="data/manifest.${set}" + + if [ $? -ne 0 ]; then + echo "Formt mnaifest.${set} failed. Terminated." + exit 1 + fi + }& + done + wait +fi + +echo "LibriSpeech Data preparation done." +exit 0 + diff --git a/examples/1xt2x/baidu_en8k/local/download_lm_en.sh b/examples/1xt2x/baidu_en8k/local/download_lm_en.sh new file mode 100755 index 000000000..dc1bdf665 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/download_lm_en.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +. ${MAIN_ROOT}/utils/utility.sh + +DIR=data/lm +mkdir -p ${DIR} + +URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm +MD5="099a601759d467cd0a8523ff939819c5" +TARGET=${DIR}/common_crawl_00.prune01111.trie.klm + +echo "Download language model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download the language model!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/baidu_en8k/local/download_model.sh b/examples/1xt2x/baidu_en8k/local/download_model.sh new file mode 100644 index 000000000..6d06e3d6f --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/download_model.sh @@ -0,0 +1,19 @@ +#! /usr/bin/env bash + +. ${MAIN_ROOT}/utils/utility.sh + +URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz' +MD5=fdabeb6c96963ac85d9188f0275c6a1b +TARGET=./baidu_en8k_v1.8_to_v2.x.tar.gz + + +echo "Download BaiduEn8k model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download BaiduEn8k model!" + exit 1 +fi +tar -zxvf $TARGET + + +exit 0 diff --git a/examples/1xt2x/baidu_en8k/local/test.sh b/examples/1xt2x/baidu_en8k/local/test.sh new file mode 100755 index 000000000..4d00f30b8 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/local/test.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 +model_type=$3 + +# download language model +bash local/download_lm_en.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/baidu_en8k/path.sh b/examples/1xt2x/baidu_en8k/path.sh new file mode 100644 index 000000000..080ab1f79 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/path.sh @@ -0,0 +1,16 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` +export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} +export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=deepspeech2 +export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin +echo "BIN_DIR "${BIN_DIR} diff --git a/examples/1xt2x/baidu_en8k/run.sh b/examples/1xt2x/baidu_en8k/run.sh new file mode 100755 index 000000000..c590312d1 --- /dev/null +++ b/examples/1xt2x/baidu_en8k/run.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/deepspeech2.yaml +avg_num=1 +model_type=offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +v18_ckpt=baidu_en8k_v1.8 +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + mkdir -p exp/${ckpt}/checkpoints + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 +fi + diff --git a/examples/1xt2x/deepspeech2x/__init__.py b/examples/1xt2x/deepspeech2x/__init__.py new file mode 100644 index 000000000..d85a3dde7 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/__init__.py @@ -0,0 +1,370 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any +from typing import List +from typing import Tuple +from typing import Union + +import paddle +from paddle import nn +from paddle.fluid import core +from paddle.nn import functional as F + +from deepspeech.utils.log import Log + +#TODO(Hui Zhang): remove fluid import +logger = Log(__name__).getlog() + +########### hcak logging ############# +logger.warn = logger.warning + +########### hcak paddle ############# +paddle.half = 'float16' +paddle.float = 'float32' +paddle.double = 'float64' +paddle.short = 'int16' +paddle.int = 'int32' +paddle.long = 'int64' +paddle.uint16 = 'uint16' +paddle.cdouble = 'complex128' + + +def convert_dtype_to_string(tensor_dtype): + """ + Convert the data type in numpy to the data type in Paddle + Args: + tensor_dtype(core.VarDesc.VarType): the data type in numpy. + Returns: + core.VarDesc.VarType: the data type in Paddle. + """ + dtype = tensor_dtype + if dtype == core.VarDesc.VarType.FP32: + return paddle.float32 + elif dtype == core.VarDesc.VarType.FP64: + return paddle.float64 + elif dtype == core.VarDesc.VarType.FP16: + return paddle.float16 + elif dtype == core.VarDesc.VarType.INT32: + return paddle.int32 + elif dtype == core.VarDesc.VarType.INT16: + return paddle.int16 + elif dtype == core.VarDesc.VarType.INT64: + return paddle.int64 + elif dtype == core.VarDesc.VarType.BOOL: + return paddle.bool + elif dtype == core.VarDesc.VarType.BF16: + # since there is still no support for bfloat16 in NumPy, + # uint16 is used for casting bfloat16 + return paddle.uint16 + elif dtype == core.VarDesc.VarType.UINT8: + return paddle.uint8 + elif dtype == core.VarDesc.VarType.INT8: + return paddle.int8 + elif dtype == core.VarDesc.VarType.COMPLEX64: + return paddle.complex64 + elif dtype == core.VarDesc.VarType.COMPLEX128: + return paddle.complex128 + else: + raise ValueError("Not supported tensor dtype %s" % dtype) + + +if not hasattr(paddle, 'softmax'): + logger.warn("register user softmax to paddle, remove this when fixed!") + setattr(paddle, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle, 'log_softmax'): + logger.warn("register user log_softmax to paddle, remove this when fixed!") + setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax) + +if not hasattr(paddle, 'sigmoid'): + logger.warn("register user sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle, 'log_sigmoid'): + logger.warn("register user log_sigmoid to paddle, remove this when fixed!") + setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid) + +if not hasattr(paddle, 'relu'): + logger.warn("register user relu to paddle, remove this when fixed!") + setattr(paddle, 'relu', paddle.nn.functional.relu) + + +def cat(xs, dim=0): + return paddle.concat(xs, axis=dim) + + +if not hasattr(paddle, 'cat'): + logger.warn( + "override cat of paddle if exists or register, remove this when fixed!") + paddle.cat = cat + + +########### hcak paddle.Tensor ############# +def item(x: paddle.Tensor): + return x.numpy().item() + + +if not hasattr(paddle.Tensor, 'item'): + logger.warn( + "override item of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.item = item + + +def func_long(x: paddle.Tensor): + return paddle.cast(x, paddle.long) + + +if not hasattr(paddle.Tensor, 'long'): + logger.warn( + "override long of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.long = func_long + +if not hasattr(paddle.Tensor, 'numel'): + logger.warn( + "override numel of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.numel = paddle.numel + + +def new_full(x: paddle.Tensor, + size: Union[List[int], Tuple[int], paddle.Tensor], + fill_value: Union[float, int, bool, paddle.Tensor], + dtype=None): + return paddle.full(size, fill_value, dtype=x.dtype) + + +if not hasattr(paddle.Tensor, 'new_full'): + logger.warn( + "override new_full of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.new_full = new_full + + +def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor: + if convert_dtype_to_string(xs.dtype) == paddle.bool: + xs = xs.astype(paddle.int) + return xs.equal( + paddle.to_tensor( + ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place)) + + +if not hasattr(paddle.Tensor, 'eq'): + logger.warn( + "override eq of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.eq = eq + +if not hasattr(paddle, 'eq'): + logger.warn( + "override eq of paddle if exists or register, remove this when fixed!") + paddle.eq = eq + + +def contiguous(xs: paddle.Tensor) -> paddle.Tensor: + return xs + + +if not hasattr(paddle.Tensor, 'contiguous'): + logger.warn( + "override contiguous of paddle.Tensor if exists or register, remove this when fixed!" + ) + paddle.Tensor.contiguous = contiguous + + +def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + nargs = len(args) + assert (nargs <= 1) + s = paddle.shape(xs) + if nargs == 1: + return s[args[0]] + else: + return s + + +#`to_static` do not process `size` property, maybe some `paddle` api dependent on it. +logger.warn( + "override size of paddle.Tensor " + "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!" +) +paddle.Tensor.size = size + + +def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor: + return xs.reshape(args) + + +if not hasattr(paddle.Tensor, 'view'): + logger.warn("register user view to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view = view + + +def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor: + return xs.reshape(ys.size()) + + +if not hasattr(paddle.Tensor, 'view_as'): + logger.warn( + "register user view_as to paddle.Tensor, remove this when fixed!") + paddle.Tensor.view_as = view_as + + +def is_broadcastable(shp1, shp2): + for a, b in zip(shp1[::-1], shp2[::-1]): + if a == 1 or b == 1 or a == b: + pass + else: + return False + return True + + +def masked_fill(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]): + assert is_broadcastable(xs.shape, mask.shape) is True + bshape = paddle.broadcast_shape(xs.shape, mask.shape) + mask = mask.broadcast_to(bshape) + trues = paddle.ones_like(xs) * value + xs = paddle.where(mask, trues, xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill'): + logger.warn( + "register user masked_fill to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill = masked_fill + + +def masked_fill_(xs: paddle.Tensor, + mask: paddle.Tensor, + value: Union[float, int]) -> paddle.Tensor: + assert is_broadcastable(xs.shape, mask.shape) is True + bshape = paddle.broadcast_shape(xs.shape, mask.shape) + mask = mask.broadcast_to(bshape) + trues = paddle.ones_like(xs) * value + ret = paddle.where(mask, trues, xs) + paddle.assign(ret.detach(), output=xs) + return xs + + +if not hasattr(paddle.Tensor, 'masked_fill_'): + logger.warn( + "register user masked_fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.masked_fill_ = masked_fill_ + + +def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: + val = paddle.full_like(xs, value) + paddle.assign(val.detach(), output=xs) + return xs + + +if not hasattr(paddle.Tensor, 'fill_'): + logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!") + paddle.Tensor.fill_ = fill_ + + +def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor: + return paddle.tile(xs, size) + + +if not hasattr(paddle.Tensor, 'repeat'): + logger.warn( + "register user repeat to paddle.Tensor, remove this when fixed!") + paddle.Tensor.repeat = repeat + +if not hasattr(paddle.Tensor, 'softmax'): + logger.warn( + "register user softmax to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax) + +if not hasattr(paddle.Tensor, 'sigmoid'): + logger.warn( + "register user sigmoid to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid) + +if not hasattr(paddle.Tensor, 'relu'): + logger.warn("register user relu to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu) + + +def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor: + return x.astype(other.dtype) + + +if not hasattr(paddle.Tensor, 'type_as'): + logger.warn( + "register user type_as to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'type_as', type_as) + + +def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: + assert len(args) == 1 + if isinstance(args[0], str): # dtype + return x.astype(args[0]) + elif isinstance(args[0], paddle.Tensor): #Tensor + return x.astype(args[0].dtype) + else: # Device + return x + + +if not hasattr(paddle.Tensor, 'to'): + logger.warn("register user to to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'to', to) + + +def func_float(x: paddle.Tensor) -> paddle.Tensor: + return x.astype(paddle.float) + + +if not hasattr(paddle.Tensor, 'float'): + logger.warn("register user float to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'float', func_float) + + +def func_int(x: paddle.Tensor) -> paddle.Tensor: + return x.astype(paddle.int) + + +if not hasattr(paddle.Tensor, 'int'): + logger.warn("register user int to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'int', func_int) + + +def tolist(x: paddle.Tensor) -> List[Any]: + return x.numpy().tolist() + + +if not hasattr(paddle.Tensor, 'tolist'): + logger.warn( + "register user tolist to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'tolist', tolist) + + +########### hcak paddle.nn ############# +class GLU(nn.Layer): + """Gated Linear Units (GLU) Layer""" + + def __init__(self, dim: int=-1): + super().__init__() + self.dim = dim + + def forward(self, xs): + return F.glu(xs, axis=self.dim) + + +if not hasattr(paddle.nn, 'GLU'): + logger.warn("register user GLU to paddle.nn, remove this when fixed!") + setattr(paddle.nn, 'GLU', GLU) diff --git a/examples/1xt2x/deepspeech2x/bin/test.py b/examples/1xt2x/deepspeech2x/bin/test.py new file mode 100644 index 000000000..3fa0a61de --- /dev/null +++ b/examples/1xt2x/deepspeech2x/bin/test.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluation for DeepSpeech2 model.""" +from deepspeech2x.model import DeepSpeech2Tester as Tester + +from deepspeech.exps.deepspeech2.config import get_cfg_defaults +from deepspeech.training.cli import default_argument_parser +from deepspeech.utils.utility import print_arguments + + +def main_sp(config, args): + exp = Tester(config, args) + exp.setup() + exp.run_test() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + parser.add_argument("--model_type") + # save asr result to + parser.add_argument( + "--result_file", type=str, help="path of save the asr result") + args = parser.parse_args() + print_arguments(args, globals()) + if args.model_type is None: + args.model_type = 'offline' + print("model_type:{}".format(args.model_type)) + + # https://yaml.org/type/float.html + config = get_cfg_defaults(args.model_type) + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/examples/1xt2x/deepspeech2x/model.py b/examples/1xt2x/deepspeech2x/model.py new file mode 100644 index 000000000..cbbc502d2 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/model.py @@ -0,0 +1,427 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains DeepSpeech2 and DeepSpeech2Online model.""" +import time +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Optional + +import numpy as np +import paddle +from deepspeech2x.models.ds2 import DeepSpeech2InferModel +from deepspeech2x.models.ds2 import DeepSpeech2Model +from paddle import distributed as dist +from paddle.io import DataLoader +from yacs.config import CfgNode + +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.io.sampler import SortagradBatchSampler +from deepspeech.io.sampler import SortagradDistributedBatchSampler +from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline +from deepspeech.models.ds2_online import DeepSpeech2ModelOnline +from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog +from deepspeech.training.trainer import Trainer +from deepspeech.utils import error_rate +from deepspeech.utils import layer_tools +from deepspeech.utils import mp_tools +from deepspeech.utils.log import Log +#from deepspeech.utils.log import Autolog + +logger = Log(__name__).getlog() + + +class DeepSpeech2Trainer(Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # training config + default = CfgNode( + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def train_batch(self, batch_index, batch_data, msg): + train_conf = self.config.training + start = time.time() + + # forward + utt, audio, audio_len, text, text_len = batch_data + loss = self.model(audio, audio_len, text, text_len) + losses_np = { + 'train_loss': float(loss), + } + + # loss backward + if (batch_index + 1) % train_conf.accum_grad != 0: + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + context = self.model.no_sync + else: + # Used for single gpu training and DDP gradient synchronization + # processes. + context = nullcontext + + with context(): + loss.backward() + layer_tools.print_grads(self.model, print_func=None) + + # optimizer step + if (batch_index + 1) % train_conf.accum_grad == 0: + self.optimizer.step() + self.optimizer.clear_grad() + self.iteration += 1 + + iteration_time = time.time() - start + + msg += "train time: {:>.3f}s, ".format(iteration_time) + msg += "batch size: {}, ".format(self.config.collator.batch_size) + msg += "accum: {}, ".format(train_conf.accum_grad) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_np.items()) + logger.info(msg) + + if dist.get_rank() == 0 and self.visualizer: + for k, v in losses_np.items(): + # `step -1` since we update `step` after optimizer.step(). + self.visualizer.add_scalar("train/{}".format(k), v, + self.iteration - 1) + + @paddle.no_grad() + def valid(self): + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + self.model.eval() + valid_losses = defaultdict(list) + num_seen_utts = 1 + total_loss = 0.0 + for i, batch in enumerate(self.valid_loader): + utt, audio, audio_len, text, text_len = batch + loss = self.model(audio, audio_len, text, text_len) + if paddle.isfinite(loss): + num_utts = batch[1].shape[0] + num_seen_utts += num_utts + total_loss += float(loss) * num_utts + valid_losses['val_loss'].append(float(loss)) + + if (i + 1) % self.config.training.log_interval == 0: + valid_dump = {k: np.mean(v) for k, v in valid_losses.items()} + valid_dump['val_history_loss'] = total_loss / num_seen_utts + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_dump.items()) + logger.info(msg) + + logger.info('Rank {} Val info val_loss {}'.format( + dist.get_rank(), total_loss / num_seen_utts)) + return total_loss, num_seen_utts + + def setup_model(self): + config = self.config.clone() + config.defrost() + config.model.feat_size = self.train_loader.collate_fn.feature_size + #config.model.dict_size = self.train_loader.collate_fn.vocab_size + config.model.dict_size = len(self.train_loader.collate_fn.vocab_list) + config.freeze() + + if self.args.model_type == 'offline': + model = DeepSpeech2Model.from_config(config.model) + elif self.args.model_type == 'online': + model = DeepSpeech2ModelOnline.from_config(config.model) + else: + raise Exception("wrong model type") + if self.parallel: + model = paddle.DataParallel(model) + + logger.info(f"{model}") + layer_tools.print_params(model, logger.info) + + grad_clip = ClipGradByGlobalNormWithLog( + config.training.global_grad_clip) + lr_scheduler = paddle.optimizer.lr.ExponentialDecay( + learning_rate=config.training.lr, + gamma=config.training.lr_decay, + verbose=True) + optimizer = paddle.optimizer.Adam( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=paddle.regularizer.L2Decay( + config.training.weight_decay), + grad_clip=grad_clip) + + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + logger.info("Setup model/optimizer/lr_scheduler!") + + def setup_dataloader(self): + config = self.config.clone() + config.defrost() + config.collator.keep_transcription_text = False + + config.data.manifest = config.data.train_manifest + train_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.dev_manifest + dev_dataset = ManifestDataset.from_config(config) + + config.data.manifest = config.data.test_manifest + test_dataset = ManifestDataset.from_config(config) + + if self.parallel: + batch_sampler = SortagradDistributedBatchSampler( + train_dataset, + batch_size=config.collator.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + else: + batch_sampler = SortagradBatchSampler( + train_dataset, + shuffle=True, + batch_size=config.collator.batch_size, + drop_last=True, + sortagrad=config.collator.sortagrad, + shuffle_method=config.collator.shuffle_method) + + collate_fn_train = SpeechCollator.from_config(config) + + config.collator.augmentation_config = "" + collate_fn_dev = SpeechCollator.from_config(config) + + config.collator.keep_transcription_text = True + config.collator.augmentation_config = "" + collate_fn_test = SpeechCollator.from_config(config) + + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn_train, + num_workers=config.collator.num_workers) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.collator.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_dev) + self.test_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn_test) + if "" in self.test_loader.collate_fn.vocab_list: + self.test_loader.collate_fn.vocab_list.remove("") + if "" in self.valid_loader.collate_fn.vocab_list: + self.valid_loader.collate_fn.vocab_list.remove("") + if "" in self.train_loader.collate_fn.vocab_list: + self.train_loader.collate_fn.vocab_list.remove("") + logger.info("Setup train/valid/test Dataloader!") + + +class DeepSpeech2Tester(DeepSpeech2Trainer): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # testing config + default = CfgNode( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, config, args): + super().__init__(config, args) + + def ordid2token(self, texts, texts_len): + """ ord() id to chr() chr """ + trans = [] + for text, n in zip(texts, texts_len): + n = n.numpy().item() + ids = text[:n] + trans.append(''.join([chr(i) for i in ids])) + return trans + + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): + cfg = self.config.decoding + errors_sum, len_refs, num_ins = 0.0, 0, 0 + errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors + error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer + + vocab_list = self.test_loader.collate_fn.vocab_list + if "" in vocab_list: + space_id = vocab_list.index("") + vocab_list[space_id] = " " + + target_transcripts = self.ordid2token(texts, texts_len) + + result_transcripts = self.compute_result_transcripts(audio, audio_len, + vocab_list, cfg) + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): + errors, len_ref = errors_func(target, result) + errors_sum += errors + len_refs += len_ref + num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") + logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target, result)) + logger.info("Current error rate [%s] = %f" % + (cfg.error_rate_type, error_rate_func(target, result))) + + return dict( + errors_sum=errors_sum, + len_refs=len_refs, + num_ins=num_ins, + error_rate=errors_sum / len_refs, + error_rate_type=cfg.error_rate_type) + + def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): + result_transcripts = self.model.decode( + audio, + audio_len, + vocab_list, + decoding_method=cfg.decoding_method, + lang_model_path=cfg.lang_model_path, + beam_alpha=cfg.alpha, + beam_beta=cfg.beta, + beam_size=cfg.beam_size, + cutoff_prob=cfg.cutoff_prob, + cutoff_top_n=cfg.cutoff_top_n, + num_processes=cfg.num_proc_bsearch) + return result_transcripts + + @mp_tools.rank_zero_only + @paddle.no_grad() + def test(self): + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + self.model.eval() + cfg = self.config + error_rate_type = None + errors_sum, len_refs, num_ins = 0.0, 0, 0 + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + utts, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) + + # logging + msg = "Test: " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "Final error rate [%s] (%d/%d) = %f" % ( + error_rate_type, num_ins, num_ins, errors_sum / len_refs) + logger.info(msg) + + # self.autolog.report() + + def run_test(self): + self.resume_or_scratch() + try: + self.test() + except KeyboardInterrupt: + exit(-1) + + def export(self): + if self.args.model_type == 'offline': + infer_model = DeepSpeech2InferModel.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + elif self.args.model_type == 'online': + infer_model = DeepSpeech2InferModelOnline.from_pretrained( + self.test_loader, self.config, self.args.checkpoint_path) + else: + raise Exception("wrong model type") + + infer_model.eval() + feat_dim = self.test_loader.collate_fn.feature_size + static_model = infer_model.export() + logger.info(f"Export code: {static_model.forward.code}") + paddle.jit.save(static_model, self.args.export_path) + + def run_export(self): + try: + self.export() + except KeyboardInterrupt: + exit(-1) + + def setup(self): + """Setup the experiment. + """ + paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu') + + self.setup_output_dir() + self.setup_checkpointer() + + self.setup_dataloader() + self.setup_model() + + self.iteration = 0 + self.epoch = 0 + + def setup_output_dir(self): + """Create a directory used for output. + """ + # output dir + if self.args.output: + output_dir = Path(self.args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + else: + output_dir = Path( + self.args.checkpoint_path).expanduser().parent.parent + output_dir.mkdir(parents=True, exist_ok=True) + + self.output_dir = output_dir diff --git a/examples/1xt2x/deepspeech2x/models/__init__.py b/examples/1xt2x/deepspeech2x/models/__init__.py new file mode 100644 index 000000000..185a92b8d --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/1xt2x/deepspeech2x/models/ds2/__init__.py b/examples/1xt2x/deepspeech2x/models/ds2/__init__.py new file mode 100644 index 000000000..39bea5bf9 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/ds2/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .deepspeech2 import DeepSpeech2InferModel +from .deepspeech2 import DeepSpeech2Model + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] diff --git a/examples/1xt2x/deepspeech2x/models/ds2/deepspeech2.py b/examples/1xt2x/deepspeech2x/models/ds2/deepspeech2.py new file mode 100644 index 000000000..f154ddb54 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/ds2/deepspeech2.py @@ -0,0 +1,314 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Deepspeech2 ASR Model""" +from typing import Optional + +import paddle +from deepspeech2x.models.ds2.rnn import RNNStack +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.models.ds2.conv import ConvStack +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.utils import layer_tools +from deepspeech.utils.checkpoint import Checkpoint +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel'] + + +class CRNNEncoder(nn.Layer): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True): + super().__init__() + self.rnn_size = rnn_size + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + + self.conv = ConvStack(feat_size, num_conv_layers) + + i_size = self.conv.output_height # H after conv stack + self.rnn = RNNStack( + i_size=i_size, + h_size=rnn_size, + num_stacks=num_rnn_layers, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + + @property + def output_size(self): + return self.rnn_size * 2 + + def forward(self, audio, audio_len): + """Compute Encoder outputs + + Args: + audio (Tensor): [B, Tmax, D] + text (Tensor): [B, Umax] + audio_len (Tensor): [B] + text_len (Tensor): [B] + Returns: + x (Tensor): encoder outputs, [B, T, D] + x_lens (Tensor): encoder length, [B] + """ + # [B, T, D] -> [B, D, T] + audio = audio.transpose([0, 2, 1]) + # [B, D, T] -> [B, C=1, D, T] + x = audio.unsqueeze(1) + x_lens = audio_len + + # convolution group + x, x_lens = self.conv(x, x_lens) + x_val = x.numpy() + + # convert data from convolution feature map to sequence of vectors + #B, C, D, T = paddle.shape(x) # not work under jit + x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + #x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit + x = x.reshape([0, 0, -1]) #[B, T, C*D] + + # remove padding part + x, x_lens = self.rnn(x, x_lens) #[B, T, D] + return x, x_lens + + +class DeepSpeech2Model(nn.Layer): + """The DeepSpeech2 network structure. + + :param audio_data: Audio spectrogram data layer. + :type audio_data: Variable + :param text_data: Transcription text data layer. + :type text_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable + :param masks: Masks data layer to reset padding. + :type masks: Variable + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (dimension of RNN cells). + :type rnn_size: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward direction RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: A tuple of an output unnormalized log probability layer ( + before softmax) and a ctc cost layer. + :rtype: tuple of LayerOutput + """ + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. + share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + )) + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0): + super().__init__() + self.encoder = CRNNEncoder( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + assert (self.encoder.output_size == rnn_size * 2) + + self.decoder = CTCDecoder( + odim=dict_size, # is in vocab + enc_n_units=self.encoder.output_size, + blank_id=blank_id, # first token is + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + + def forward(self, audio, audio_len, text, text_len): + """Compute Model loss + + Args: + audio (Tenosr): [B, T, D] + audio_len (Tensor): [B] + text (Tensor): [B, U] + text_len (Tensor): [B] + + Returns: + loss (Tenosr): [1] + """ + eouts, eouts_len = self.encoder(audio, audio_len) + loss = self.decoder(eouts, eouts_len, text, text_len) + return loss + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + # init once + # decoders only accept string encoded in utf-8 + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + print("probs.shape", probs.shape) + return self.decoder.decode_probs( + probs.numpy(), eouts_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes) + + def decode_probs_split(self, probs_split, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, + cutoff_prob, cutoff_top_n, num_processes): + self.decoder.init_decode( + beam_alpha=beam_alpha, + beam_beta=beam_beta, + lang_model_path=lang_model_path, + vocab_list=vocab_list, + decoding_method=decoding_method) + return self.decoder.decode_probs_split( + probs_split, vocab_list, decoding_method, lang_model_path, + beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, + num_processes) + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + Parameters + ---------- + dataloader: paddle.io.DataLoader + + config: yacs.config.CfgNode + model configs + + checkpoint_path: Path or str + the path of pretrained model checkpoint, without extension name + + Returns + ------- + DeepSpeech2Model + The model built from pretrained result. + """ + model = cls(feat_size=dataloader.collate_fn.feature_size, + dict_size=len(dataloader.collate_fn.vocab_list), + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + use_gru=config.model.use_gru, + share_rnn_weights=config.model.share_rnn_weights) + infos = Checkpoint().load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + @classmethod + def from_config(cls, config): + """Build a DeepSpeec2Model from config + Parameters + + config: yacs.config.CfgNode + config.model + Returns + ------- + DeepSpeech2Model + The model built from config. + """ + model = cls(feat_size=config.feat_size, + dict_size=config.dict_size, + num_conv_layers=config.num_conv_layers, + num_rnn_layers=config.num_rnn_layers, + rnn_size=config.rnn_layer_size, + use_gru=config.use_gru, + share_rnn_weights=config.share_rnn_weights, + blank_id=config.blank_id) + return model + + +class DeepSpeech2InferModel(DeepSpeech2Model): + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=1024, + use_gru=False, + share_rnn_weights=True, + blank_id=0): + super().__init__( + feat_size=feat_size, + dict_size=dict_size, + num_conv_layers=num_conv_layers, + num_rnn_layers=num_rnn_layers, + rnn_size=rnn_size, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights, + blank_id=blank_id) + + def forward(self, audio, audio_len): + """export model function + + Args: + audio (Tensor): [B, T, D] + audio_len (Tensor): [B] + + Returns: + probs: probs after softmax + """ + eouts, eouts_len = self.encoder(audio, audio_len) + probs = self.decoder.softmax(eouts) + return probs, eouts_len + + def export(self): + static_model = paddle.jit.to_static( + self, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None, self.encoder.feat_size], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ]) + return static_model diff --git a/examples/1xt2x/deepspeech2x/models/ds2/rnn.py b/examples/1xt2x/deepspeech2x/models/ds2/rnn.py new file mode 100644 index 000000000..e45db7c05 --- /dev/null +++ b/examples/1xt2x/deepspeech2x/models/ds2/rnn.py @@ -0,0 +1,334 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I + +from deepspeech.modules.activation import brelu +from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ['RNNStack'] + + +class RNNCell(nn.RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + The formula used is as follows: + .. math:: + h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`act` is for :attr:`activation`. + """ + + def __init__(self, + hidden_size: int, + activation="tanh", + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + if activation not in ["tanh", "relu", "brelu"]: + raise ValueError( + "activation for SimpleRNNCell should be tanh or relu, " + "but get {}".format(activation)) + self.activation = activation + self._activation_fn = paddle.tanh \ + if activation == "tanh" \ + else F.relu + if activation == 'brelu': + self._activation_fn = brelu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + pre_h = states + i2h = inputs + if self.bias_ih is not None: + i2h += self.bias_ih + h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h2h += self.bias_hh + h = self._activation_fn(i2h + h2h) + return h, h + + @property + def state_shape(self): + return (self.hidden_size, ) + + +class GRUCell(nn.RNNCellBase): + r""" + Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, + it computes the outputs and updates states. + The formula for GRU used is as follows: + .. math:: + r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) + z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) + \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) + h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} + y_{t} & = h_{t} + + where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise + multiplication operator. + """ + + def __init__(self, + input_size: int, + hidden_size: int, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = None + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.relu + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states # shape [batch_size, hidden_size] + + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + bias_u, bias_r, bias_c = paddle.split( + self.bias_hh, num_or_sections=3, axis=0) + + weight_hh = paddle.transpose( + self.weight_hh, + perm=[1, 0]) #weight_hh:shape[hidden_size, 3 * hidden_size] + w_u_r_c = paddle.flatten(weight_hh) + size_u_r = self.hidden_size * 2 * self.hidden_size + w_u_r = paddle.reshape(w_u_r_c[:size_u_r], + (self.hidden_size, self.hidden_size * 2)) + w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1) + w_c = paddle.reshape(w_u_r_c[size_u_r:], + (self.hidden_size, self.hidden_size)) + + h_u = paddle.matmul( + pre_hidden, w_u, + transpose_y=False) + bias_u #shape [batch_size, hidden_size] + h_r = paddle.matmul( + pre_hidden, w_r, + transpose_y=False) + bias_r #shape [batch_size, hidden_size] + + x_u, x_r, x_c = paddle.split( + x_gates, num_or_sections=3, axis=1) #shape[batch_size, hidden_size] + + u = self._gate_activation(x_u + h_u) #shape [batch_size, hidden_size] + r = self._gate_activation(x_r + h_r) #shape [batch_size, hidden_size] + c = self._activation( + x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) + + bias_c) # [batch_size, hidden_size] + + h = (1 - u) * pre_hidden + u * c + # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): + """Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param size: Dimension of RNN cells. + :type size: int + :param share_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + :type share_weights: bool + :return: Bidirectional simple rnn layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int, share_weights: bool): + super().__init__() + self.share_weights = share_weights + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. + self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + # batch norm is only performed on input-state projection + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = self.fw_fc + self.bw_bn = self.fw_bn + else: + self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + h_size, bias_attr=None, data_format='NLC') + + self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu') + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class BiGRUWithBN(nn.Layer): + """Bidirectonal gru layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + + :param name: Name of the layer. + :type name: string + :param input: Input layer. + :type input: Variable + :param size: Dimension of GRU cells. + :type size: int + :param act: Activation type. + :type act: string + :return: Bidirectional GRU layer. + :rtype: Variable + """ + + def __init__(self, i_size: int, h_size: int): + super().__init__() + hidden_size = h_size * 3 + + self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.fw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False) + self.bw_bn = nn.BatchNorm1D( + hidden_size, bias_attr=None, data_format='NLC') + + self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size) + self.fw_rnn = nn.RNN( + self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] + self.bw_rnn = nn.RNN( + self.bw_cell, is_reverse=True, time_major=False) #[B, T, D] + + def forward(self, x, x_len): + # x, shape [B, T, D] + fw_x = self.fw_bn(self.fw_fc(x)) + + bw_x = self.bw_bn(self.bw_fc(x)) + fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) + bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) + x = paddle.concat([fw_x, bw_x], axis=-1) + return x, x_len + + +class RNNStack(nn.Layer): + """RNN group with stacked bidirectional simple RNN or GRU layers. + + :param input: Input layer. + :type input: Variable + :param size: Dimension of RNN cells in each layer. + :type size: int + :param num_stacks: Number of stacked rnn layers. + :type num_stacks: int + :param use_gru: Use gru if set True. Use simple rnn if set False. + :type use_gru: bool + :param share_rnn_weights: Whether to share input-hidden weights between + forward and backward directional RNNs. + It is only available when use_gru=False. + :type share_weights: bool + :return: Output layer of the RNN group. + :rtype: Variable + """ + + def __init__(self, + i_size: int, + h_size: int, + num_stacks: int, + use_gru: bool, + share_rnn_weights: bool): + super().__init__() + rnn_stacks = [] + for i in range(num_stacks): + if use_gru: + #default:GRU using tanh + rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size)) + else: + rnn_stacks.append( + BiRNNWithBN( + i_size=i_size, + h_size=h_size, + share_weights=share_rnn_weights)) + i_size = h_size * 2 + + self.rnn_stacks = nn.LayerList(rnn_stacks) + + def forward(self, x: paddle.Tensor, x_len: paddle.Tensor): + """ + x: shape [B, T, D] + x_len: shpae [B] + """ + for i, rnn in enumerate(self.rnn_stacks): + x, x_len = rnn(x, x_len) + masks = make_non_pad_mask(x_len) #[B, T] + masks = masks.unsqueeze(-1) # [B, T, 1] + # TODO(Hui Zhang): not support bool multiply + masks = masks.astype(x.dtype) + x = x.multiply(masks) + return x, x_len diff --git a/examples/1xt2x/librispeech/.gitignore b/examples/1xt2x/librispeech/.gitignore new file mode 100644 index 000000000..7024e0e95 --- /dev/null +++ b/examples/1xt2x/librispeech/.gitignore @@ -0,0 +1,4 @@ +exp +data +*log +tmp diff --git a/examples/1xt2x/librispeech/conf/augmentation.json b/examples/1xt2x/librispeech/conf/augmentation.json new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ b/examples/1xt2x/librispeech/conf/augmentation.json @@ -0,0 +1 @@ +[] diff --git a/examples/1xt2x/librispeech/conf/deepspeech2.yaml b/examples/1xt2x/librispeech/conf/deepspeech2.yaml new file mode 100644 index 000000000..edef07972 --- /dev/null +++ b/examples/1xt2x/librispeech/conf/deepspeech2.yaml @@ -0,0 +1,67 @@ +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test-clean + min_input_len: 0.0 + max_input_len: 1000.0 # second + min_output_len: 0.0 + max_output_len: .inf + min_output_input_ratio: 0.00 + max_output_input_ratio: .inf + +collator: + batch_size: 64 # one gpu + mean_std_filepath: data/mean_std.npz + unit_type: char + vocab_filepath: data/vocab.txt + augmentation_config: conf/augmentation.json + random_seed: 0 + spm_model_prefix: + specgram_type: linear + feat_dim: + delta_delta: False + stride_ms: 10.0 + window_ms: 20.0 + n_fft: None + max_freq: None + target_sample_rate: 16000 + use_dB_normalization: True + target_dB: -20 + dither: 1.0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + num_workers: 2 + +model: + num_conv_layers: 2 + num_rnn_layers: 3 + rnn_layer_size: 2048 + use_gru: False + share_rnn_weights: True + blank_id: 28 + +training: + n_epoch: 80 + accum_grad: 1 + lr: 2e-3 + lr_decay: 0.83 + weight_decay: 1e-06 + global_grad_clip: 3.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + +decoding: + batch_size: 32 + error_rate_type: wer + decoding_method: ctc_beam_search + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 500 + cutoff_prob: 1.0 + cutoff_top_n: 40 + num_proc_bsearch: 8 diff --git a/examples/1xt2x/librispeech/local/data.sh b/examples/1xt2x/librispeech/local/data.sh new file mode 100755 index 000000000..22a86bb2e --- /dev/null +++ b/examples/1xt2x/librispeech/local/data.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +stage=-1 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh + +mkdir -p data +TARGET_DIR=${MAIN_ROOT}/examples/dataset +mkdir -p ${TARGET_DIR} + + +bash local/download_model.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +tar xzvf librispeech_v1.8_to_v2.x.tar.gz +mv librispeech_v1.8.pdparams exp/deepspeech2/checkpoints/ +mv README.md exp/deepspeech2/ +mv mean_std.npz data/ +mv vocab.txt data/ +rm librispeech_v1.8_to_v2.x.tar.gz -f + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + # download data, generate manifests + python3 ${TARGET_DIR}/librispeech/librispeech.py \ + --manifest_prefix="data/manifest" \ + --target_dir="${TARGET_DIR}/librispeech" \ + --full_download="True" + + if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 + fi + + for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do + mv data/manifest.${set} data/manifest.${set}.raw + done + + rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw + for set in train-clean-100 train-clean-360 train-other-500; do + cat data/manifest.${set}.raw >> data/manifest.train.raw + done + + for set in dev-clean dev-other; do + cat data/manifest.${set}.raw >> data/manifest.dev.raw + done + + for set in test-clean test-other; do + cat data/manifest.${set}.raw >> data/manifest.test.raw + done +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # compute mean and stddev for normalizer + num_workers=$(nproc) + python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ + --manifest_path="data/manifest.train.raw" \ + --num_samples=2000 \ + --specgram_type="linear" \ + --delta_delta=false \ + --sample_rate=16000 \ + --stride_ms=10.0 \ + --window_ms=20.0 \ + --use_dB_normalization=True \ + --num_workers=${num_workers} \ + --output_path="data/mean_std.json" + + if [ $? -ne 0 ]; then + echo "Compute mean and stddev failed. Terminated." + exit 1 + fi +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # format manifest with tokenids, vocab size + for set in train dev test dev-clean dev-other test-clean test-other; do + { + python3 ${MAIN_ROOT}/utils/format_data.py \ + --feat_type "raw" \ + --cmvn_path "data/mean_std.json" \ + --unit_type ${unit_type} \ + --vocab_path="data/vocab.txt" \ + --manifest_path="data/manifest.${set}.raw" \ + --output_path="data/manifest.${set}" + + if [ $? -ne 0 ]; then + echo "Formt mnaifest.${set} failed. Terminated." + exit 1 + fi + }& + done + wait +fi + +echo "LibriSpeech Data preparation done." +exit 0 + diff --git a/examples/1xt2x/librispeech/local/download_lm_en.sh b/examples/1xt2x/librispeech/local/download_lm_en.sh new file mode 100755 index 000000000..dc1bdf665 --- /dev/null +++ b/examples/1xt2x/librispeech/local/download_lm_en.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +. ${MAIN_ROOT}/utils/utility.sh + +DIR=data/lm +mkdir -p ${DIR} + +URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm +MD5="099a601759d467cd0a8523ff939819c5" +TARGET=${DIR}/common_crawl_00.prune01111.trie.klm + +echo "Download language model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download the language model!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/librispeech/local/download_model.sh b/examples/1xt2x/librispeech/local/download_model.sh new file mode 100644 index 000000000..cc6a9ec75 --- /dev/null +++ b/examples/1xt2x/librispeech/local/download_model.sh @@ -0,0 +1,19 @@ +#! /usr/bin/env bash + +. ${MAIN_ROOT}/utils/utility.sh + +URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz' +MD5=7b0f582fe2f5a840b840e7ee52246bc5 +TARGET=./librispeech_v1.8_to_v2.x.tar.gz + + +echo "Download LibriSpeech model ..." +download $URL $MD5 $TARGET +if [ $? -ne 0 ]; then + echo "Fail to download LibriSpeech model!" + exit 1 +fi +tar -zxvf $TARGET + + +exit 0 diff --git a/examples/1xt2x/librispeech/local/test.sh b/examples/1xt2x/librispeech/local/test.sh new file mode 100755 index 000000000..4d00f30b8 --- /dev/null +++ b/examples/1xt2x/librispeech/local/test.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +if [ $# != 3 ];then + echo "usage: ${0} config_path ckpt_path_prefix model_type" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +config_path=$1 +ckpt_prefix=$2 +model_type=$3 + +# download language model +bash local/download_lm_en.sh +if [ $? -ne 0 ]; then + exit 1 +fi + +python3 -u ${BIN_DIR}/test.py \ +--nproc ${ngpu} \ +--config ${config_path} \ +--result_file ${ckpt_prefix}.rsl \ +--checkpoint_path ${ckpt_prefix} \ +--model_type ${model_type} + +if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 +fi + + +exit 0 diff --git a/examples/1xt2x/librispeech/path.sh b/examples/1xt2x/librispeech/path.sh new file mode 100644 index 000000000..080ab1f79 --- /dev/null +++ b/examples/1xt2x/librispeech/path.sh @@ -0,0 +1,16 @@ +export MAIN_ROOT=`realpath ${PWD}/../../../` +export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} +export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH} + +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ + +MODEL=deepspeech2 +export BIN_DIR=${LOCAL_DEEPSPEECH2}/deepspeech2x/bin +echo "BIN_DIR "${BIN_DIR} diff --git a/examples/1xt2x/librispeech/run.sh b/examples/1xt2x/librispeech/run.sh new file mode 100755 index 000000000..05706a428 --- /dev/null +++ b/examples/1xt2x/librispeech/run.sh @@ -0,0 +1,26 @@ +#!/bin/bash +set -e +source path.sh + +stage=0 +stop_stage=100 +conf_path=conf/deepspeech2.yaml +avg_num=1 +model_type=offline + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +v18_ckpt=librispeech_v1.8 +ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +echo "checkpoint name ${ckpt}" + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + mkdir -p exp/${ckpt}/checkpoints + bash ./local/data.sh || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # test ckpt avg_n + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type}|| exit -1 +fi diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index e5ebfcbaf..ee0f1405e 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -10,8 +10,11 @@ | Model | Params | Release | Config | Test set | Loss | CER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | +| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug | test | 6.016139030456543 | 0.066549 | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 58.4M | 7181e427 | conf/deepspeech2.yaml + spec aug | test | 5.71956205368042 | 0.064287 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | | DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | | DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | +| --- | --- | --- | --- | --- | --- | --- | | DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 7f0a1462f..9560930ac 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -40,9 +40,12 @@ model: rnn_layer_size: 1024 use_gru: True share_rnn_weights: False + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 80 + accum_grad: 1 lr: 2e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/aishell/s0/conf/deepspeech2_online.yaml b/examples/aishell/s0/conf/deepspeech2_online.yaml index fdc3a5365..7e87594cc 100644 --- a/examples/aishell/s0/conf/deepspeech2_online.yaml +++ b/examples/aishell/s0/conf/deepspeech2_online.yaml @@ -36,17 +36,20 @@ collator: model: num_conv_layers: 2 - num_rnn_layers: 3 + num_rnn_layers: 5 rnn_layer_size: 1024 rnn_direction: forward # [forward, bidirect] - num_fc_layers: 1 - fc_layers_size_list: 512, + num_fc_layers: 0 + fc_layers_size_list: -1, use_gru: False - + blank_id: 0 + ctc_grad_norm_type: instance + training: n_epoch: 50 + accum_grad: 1 lr: 2e-3 - lr_decay: 0.91 # 0.83 + lr_decay: 0.9 # 0.83 weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 @@ -59,7 +62,7 @@ decoding: error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm - alpha: 1.9 + alpha: 2.2 #1.9 beta: 5.0 beam_size: 300 cutoff_prob: 0.99 diff --git a/examples/aishell/s0/local/client.sh b/examples/aishell/s0/local/client.sh deleted file mode 100755 index 3b59ad3df..000000000 --- a/examples/aishell/s0/local/client.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -source path.sh - -# run on MacOS -# brew install portaudio -# pip install pyaudio -# pip install keyboard - -# start demo client -python3 -u ${BIN_DIR}/deploy/client.py \ ---host_ip="localhost" \ ---host_port=8086 \ - -if [ $? -ne 0 ]; then - echo "Failed in starting demo client!" - exit 1 -fi - -exit 0 diff --git a/examples/aishell/s0/local/export.sh b/examples/aishell/s0/local/export.sh index 2e09e5f5e..a5e62c28d 100755 --- a/examples/aishell/s0/local/export.sh +++ b/examples/aishell/s0/local/export.sh @@ -13,13 +13,7 @@ ckpt_path_prefix=$2 jit_model_export_path=$3 model_type=$4 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/aishell/s0/local/server.sh b/examples/aishell/s0/local/server.sh deleted file mode 100755 index 2b8810993..000000000 --- a/examples/aishell/s0/local/server.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -# TODO: replace the model with a mandarin model - -if [[ $# != 1 ]];then - echo "usage: $1 checkpoint_path" - exit -1 -fi - -source path.sh - -# download language model -bash local/download_lm_ch.sh -if [ $? -ne 0 ]; then - exit 1 -fi - -# download well-trained model -#bash local/download_model.sh -#if [ $? -ne 0 ]; then -# exit 1 -#fi - -# start demo server -CUDA_VISIBLE_DEVICES=0 \ -python3 -u ${BIN_DIR}/deploy/server.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---host_ip="localhost" \ ---host_port=8086 \ ---speech_save_dir="demo_cache" \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in starting demo server!" - exit 1 -fi - - -exit 0 diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index 9fd0bc8d5..2ae0740b3 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 model_type=$3 @@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then fi python3 -u ${BIN_DIR}/test.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/aishell/s0/local/test_export.sh b/examples/aishell/s0/local/test_export.sh index b6d580979..a9a6b122d 100755 --- a/examples/aishell/s0/local/test_export.sh +++ b/examples/aishell/s0/local/test_export.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 jit_model_export_path=$2 model_type=$3 @@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then fi python3 -u ${BIN_DIR}/test_export.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${jit_model_export_path}.rsl \ --export_path ${jit_model_export_path} \ diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index 3438a7357..edbf33830 100755 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -12,27 +12,22 @@ config_path=$1 ckpt_name=$2 model_type=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - mkdir -p exp +# seed may break model convergence seed=10086 -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --model_type ${model_type} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/aishell/s0/local/tune.sh b/examples/aishell/s0/local/tune.sh deleted file mode 100755 index 59406cd5b..000000000 --- a/examples/aishell/s0/local/tune.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -# grid-search for hyper-parameters in language model -python3 -u ${BIN_DIR}/tune.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---num_batches=10 \ ---batch_size=128 \ ---beam_size=300 \ ---num_proc_bsearch=8 \ ---num_alphas=10 \ ---num_betas=10 \ ---alpha_from=0.0 \ ---alpha_to=5.0 \ ---beta_from=-6 \ ---beta_to=6 \ ---cutoff_prob=1.0 \ ---cutoff_top_n=40 \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in tuning!" - exit 1 -fi - - -exit 0 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index e5ab12a59..71191c3ac 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -27,7 +27,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 3e606788e..6f8ae135f 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -76,6 +76,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 4b1430c58..a4248459c 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -71,6 +71,8 @@ model: # hybrid CTC/attention model_conf: ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: instance lsm_weight: 0.1 # label smoothing option length_normalized_loss: false diff --git a/examples/aishell/s1/local/align.sh b/examples/aishell/s1/local/align.sh index ad6c84bc8..279461aaf 100755 --- a/examples/aishell/s1/local/align.sh +++ b/examples/aishell/s1/local/align.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 @@ -22,8 +18,7 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` python3 -u ${BIN_DIR}/alignment.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/aishell/s1/local/export.sh b/examples/aishell/s1/local/export.sh index f99a15bad..b562218e7 100755 --- a/examples/aishell/s1/local/export.sh +++ b/examples/aishell/s1/local/export.sh @@ -12,13 +12,7 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/aishell/s1/local/test.sh b/examples/aishell/s1/local/test.sh index f7e99ad7f..c87412c9b 100755 --- a/examples/aishell/s1/local/test.sh +++ b/examples/aishell/s1/local/test.sh @@ -8,11 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - config_path=$1 ckpt_prefix=$2 @@ -39,8 +34,7 @@ for type in attention ctc_greedy_search; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ @@ -58,8 +52,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/aishell/s1/local/train.sh b/examples/aishell/s1/local/train.sh index ec17054ab..71af3a006 100755 --- a/examples/aishell/s1/local/train.sh +++ b/examples/aishell/s1/local/train.sh @@ -1,37 +1,43 @@ #!/bin/bash +profiler_options= +benchmark_batch_size=0 +benchmark_max_step=0 + +# seed may break model convergence +seed=0 + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +if [ ${seed} != 0 ]; then + export FLAGS_cudnn_deterministic=True + echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." +fi + if [ $# != 2 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name" exit -1 fi -ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') -echo "using $ngpu gpus..." - config_path=$1 ckpt_name=$2 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi -echo "using ${device}..." - mkdir -p exp -seed=1024 -if [ ${seed} ]; then - export FLAGS_cudnn_deterministic=True -fi - python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ +--seed ${seed} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ ---seed ${seed} +--profiler-options "${profiler_options}" \ +--benchmark-batch-size ${benchmark_batch_size} \ +--benchmark-max-step ${benchmark_max_step} + -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index d55d47ea6..e3c008234 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/callcenter/s1/local/align.sh b/examples/callcenter/s1/local/align.sh index f2c878c20..b679e2ea7 100755 --- a/examples/callcenter/s1/local/align.sh +++ b/examples/callcenter/s1/local/align.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 @@ -20,7 +16,6 @@ ckpt_name=$(basename ${ckpt_prefxi}) mkdir -p exp - batch_size=1 output_dir=${ckpt_prefix} mkdir -p ${output_dir} @@ -28,8 +23,7 @@ mkdir -p ${output_dir} # align dump in `result_file` # .tier, .TextGrid dump in `dir of result_file` python3 -u ${BIN_DIR}/alignment.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.align \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/callcenter/s1/local/export.sh b/examples/callcenter/s1/local/export.sh index d171899cd..d5f912e90 100755 --- a/examples/callcenter/s1/local/export.sh +++ b/examples/callcenter/s1/local/export.sh @@ -12,13 +12,7 @@ config_path=$1 ckpt_path_prefix=$2 jit_model_export_path=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/callcenter/s1/local/test.sh b/examples/callcenter/s1/local/test.sh index 7a5b1cdb1..dca3137dd 100755 --- a/examples/callcenter/s1/local/test.sh +++ b/examples/callcenter/s1/local/test.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 @@ -32,8 +28,7 @@ for type in attention ctc_greedy_search; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ @@ -51,8 +46,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do output_dir=${ckpt_prefix} mkdir -p ${output_dir} python3 -u ${BIN_DIR}/test.py \ - --device ${device} \ - --nproc 1 \ + --nproc ${ngpu} \ --config ${config_path} \ --result_file ${output_dir}/${type}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/callcenter/s1/local/train.sh b/examples/callcenter/s1/local/train.sh index 928c6492c..eb8f86626 100755 --- a/examples/callcenter/s1/local/train.sh +++ b/examples/callcenter/s1/local/train.sh @@ -11,27 +11,23 @@ echo "using $ngpu gpus..." config_path=$1 ckpt_name=$2 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi echo "using ${device}..." mkdir -p exp -seed=1024 -if [ ${seed} ]; then +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/callcenter/s1/run.sh b/examples/callcenter/s1/run.sh index 52dd44eca..305021f19 100644 --- a/examples/callcenter/s1/run.sh +++ b/examples/callcenter/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/cc-cedict/README.md b/examples/cc-cedict/README.md index e69de29bb..513fca533 100644 --- a/examples/cc-cedict/README.md +++ b/examples/cc-cedict/README.md @@ -0,0 +1,58 @@ +# [CC-CEDICT](https://cc-cedict.org/wiki/) + +What is CC-CEDICT? +CC-CEDICT is a continuation of the CEDICT project. +The objective of the CEDICT project was to create an online, downloadable (as opposed to searchable-only) public-domain Chinese-English dictionary. +CEDICT was started by Paul Andrew Denisowski in October 1997. +For the most part, the project is modeled on Jim Breen's highly successful EDICT (Japanese-English dictionary) project and is intended to be a collaborative effort, +with users providing entries and corrections to the main file. + + +## Parse CC-CEDICT to Json format + +1. Parse to Json + +``` +run.sh +``` + +2. Result + +``` +exp/ +|-- cedict +`-- cedict.json + +0 directories, 2 files +``` + +``` +4c4bffc84e24467fe1b2ea9ba37ed6b6 exp/cedict +3adf504dacd13886f88cc9fe3b37c75d exp/cedict.json +``` + +``` +==> exp/cedict <== +# CC-CEDICT +# Community maintained free Chinese-English dictionary. +# +# Published by MDBG +# +# License: +# Creative Commons Attribution-ShareAlike 4.0 International License +# https://creativecommons.org/licenses/by-sa/4.0/ +# +# Referenced works: + +==> exp/cedict.json <== +{"traditional": "2019\u51a0\u72c0\u75c5\u6bd2\u75c5", "simplified": "2019\u51a0\u72b6\u75c5\u6bd2\u75c5", "pinyin": "er4 ling2 yi1 jiu3 guan1 zhuang4 bing4 du2 bing4", "english": "COVID-19, the coronavirus disease identified in 2019"} +{"traditional": "21\u4e09\u9ad4\u7d9c\u5408\u75c7", "simplified": "21\u4e09\u4f53\u7efc\u5408\u75c7", "pinyin": "er4 shi2 yi1 san1 ti3 zong1 he2 zheng4", "english": "trisomy"} +{"traditional": "3C", "simplified": "3C", "pinyin": "san1 C", "english": "abbr. for computers, communications, and consumer electronics"} +{"traditional": "3P", "simplified": "3P", "pinyin": "san1 P", "english": "(slang) threesome"} +{"traditional": "3Q", "simplified": "3Q", "pinyin": "san1 Q", "english": "(Internet slang) thank you (loanword)"} +{"traditional": "421", "simplified": "421", "pinyin": "si4 er4 yi1", "english": "four grandparents, two parents and an only child"} +{"traditional": "502\u81a0", "simplified": "502\u80f6", "pinyin": "wu3 ling2 er4 jiao1", "english": "cyanoacrylate glue"} +{"traditional": "88", "simplified": "88", "pinyin": "ba1 ba1", "english": "(Internet slang) bye-bye (alternative for \u62dc\u62dc[bai2 bai2])"} +{"traditional": "996", "simplified": "996", "pinyin": "jiu3 jiu3 liu4", "english": "9am-9pm, six days a week (work schedule)"} +{"traditional": "A", "simplified": "A", "pinyin": "A", "english": "(slang) (Tw) to steal"} +``` diff --git a/examples/chinese_g2p/README.md b/examples/chinese_g2p/README.md deleted file mode 100644 index e3fdfe684..000000000 --- a/examples/chinese_g2p/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Download Baker dataset - -Baker dataset has to be downloaded mannually and moved to 'data/', because you will have to pass the CATTCHA from a browswe to download the dataset. - -Download URL https://test.data-baker.com/#/data/index/source. diff --git a/examples/chinese_g2p/.gitignore b/examples/g2p/.gitignore similarity index 100% rename from examples/chinese_g2p/.gitignore rename to examples/g2p/.gitignore diff --git a/examples/g2p/README.md b/examples/g2p/README.md new file mode 100644 index 000000000..4ec5922b3 --- /dev/null +++ b/examples/g2p/README.md @@ -0,0 +1,3 @@ +# G2P + +* zh - Chinese G2P diff --git a/examples/g2p/zh/README.md b/examples/g2p/zh/README.md new file mode 100644 index 000000000..de5573565 --- /dev/null +++ b/examples/g2p/zh/README.md @@ -0,0 +1,93 @@ +# G2P + +* WS +jieba +* G2P +pypinyin +* Tone sandhi +simple + +We recommend using [Paraket](https://github.com/PaddlePaddle/Parakeet] [TextFrontEnd](https://github.com/PaddlePaddle/Parakeet/blob/develop/parakeet/frontend/__init__.py) to do G2P. +The phoneme set should be changed, you can reference `examples/thchs30/a0/data/dict/syllable.lexicon`. + +## Download Baker dataset + +[Baker](https://test.data-baker.com/#/data/index/source) dataset has to be downloaded mannually and moved to './data', +because you will have to pass the `CATTCHA` from a browswe to download the dataset. + + +## RUN + +``` +. path.sh +./run.sh +``` + +## Result + +``` +exp/ +|-- 000001-010000.txt +|-- ref.pinyin +|-- trans.jieba.pinyin +`-- trans.pinyin + +0 directories, 4 files +``` + +``` +4f5a368441eb16aaf43dc1972f8b63dd exp/000001-010000.txt +01707896391c2de9b6fc4a39654be942 exp/ref.pinyin +43380ef160f65a23a3a0544700aa49b8 exp/trans.jieba.pinyin +8e6ff1fc22d8e8584082e804e8bcdeb7 exp/trans.pinyin +``` + +``` +==> exp/000001-010000.txt <== +000001 卡尔普#2陪外孙#1玩滑梯#4。 + ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 假语村言#2别再#1拥抱我#4。 + jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 宝马#1配挂#1跛骡鞍#3,貂蝉#1怨枕#2董翁榻#4。 + bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 邓小平#2与#1撒切尔#2会晤#4。 + deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4 +000005 老虎#1幼崽#2与#1宠物犬#1玩耍#4。 + lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3 + +==> exp/ref.pinyin <== +000001 ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 jia2 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 bao2 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 deng4 xiao3 ping2 yu3 sa4 qie4 er3 hui4 wu4 +000005 lao2 hu3 you4 zai3 yu2 chong3 wu4 quan3 wan2 shua3 +000006 shen1 chang2 yue1 wu2 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4 +000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1 +000008 zhan2 pin3 sui1 you3 zhan3 yuan2 que4 tui2 +000009 yi2 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3 +000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1 + +==> exp/trans.jieba.pinyin <== +000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4 +000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3 +000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4 +000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1 +000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2 +000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3 +000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1 + +==> exp/trans.pinyin <== +000001 ka3 er3 pu3 pei2 wai4 sun1 wan2 hua2 ti1 +000002 jia3 yu3 cun1 yan2 bie2 zai4 yong1 bao4 wo3 +000003 bao3 ma3 pei4 gua4 bo3 luo2 an1 diao1 chan2 yuan4 zhen3 dong3 weng1 ta4 +000004 deng4 xiao3 ping2 yu3 sa1 qie4 er3 hui4 wu4 +000005 lao3 hu3 you4 zai3 yu3 chong3 wu4 quan3 wan2 shua3 +000006 shen1 chang2 yue1 wu3 chi3 er4 cun4 wu3 fen1 huo4 yi3 shang4 +000007 zhao4 di2 yue1 cao2 yun2 teng2 qu4 gui3 wu1 +000008 zhan3 pin3 sui1 you3 zhan3 yuan2 que4 tui2 +000009 yi3 san3 ju1 er2 tong2 he2 you4 tuo1 er2 tong2 wei2 zhu3 +000010 ke1 te4 ni1 shen1 chuan1 bao4 wen2 da4 yi1 +``` diff --git a/examples/chinese_g2p/local/convert_transcription.py b/examples/g2p/zh/local/convert_transcription.py similarity index 100% rename from examples/chinese_g2p/local/convert_transcription.py rename to examples/g2p/zh/local/convert_transcription.py diff --git a/examples/chinese_g2p/local/extract_pinyin_label.py b/examples/g2p/zh/local/extract_pinyin_label.py similarity index 100% rename from examples/chinese_g2p/local/extract_pinyin_label.py rename to examples/g2p/zh/local/extract_pinyin_label.py diff --git a/examples/chinese_g2p/local/ignore_sandhi.py b/examples/g2p/zh/local/ignore_sandhi.py similarity index 100% rename from examples/chinese_g2p/local/ignore_sandhi.py rename to examples/g2p/zh/local/ignore_sandhi.py diff --git a/examples/chinese_g2p/local/prepare_dataset.sh b/examples/g2p/zh/local/prepare_dataset.sh similarity index 100% rename from examples/chinese_g2p/local/prepare_dataset.sh rename to examples/g2p/zh/local/prepare_dataset.sh diff --git a/examples/chinese_g2p/path.sh b/examples/g2p/zh/path.sh similarity index 82% rename from examples/chinese_g2p/path.sh rename to examples/g2p/zh/path.sh index 482177dc6..f475ed833 100644 --- a/examples/chinese_g2p/path.sh +++ b/examples/g2p/zh/path.sh @@ -1,4 +1,4 @@ -export MAIN_ROOT=`realpath ${PWD}/../../` +export MAIN_ROOT=`realpath ${PWD}/../../../` export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C diff --git a/examples/chinese_g2p/requirements.txt b/examples/g2p/zh/requirements.txt similarity index 100% rename from examples/chinese_g2p/requirements.txt rename to examples/g2p/zh/requirements.txt diff --git a/examples/chinese_g2p/run.sh b/examples/g2p/zh/run.sh similarity index 82% rename from examples/chinese_g2p/run.sh rename to examples/g2p/zh/run.sh index 8197dce4b..25b713110 100755 --- a/examples/chinese_g2p/run.sh +++ b/examples/g2p/zh/run.sh @@ -6,16 +6,19 @@ stage=-1 stop_stage=100 exp_dir=exp -data_dir=data +data=data source ${MAIN_ROOT}/utils/parse_options.sh || exit -1 mkdir -p ${exp_dir} +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ];then + test -e ${data}/BZNSYP.rar || { echo "Please download BZNSYP.rar and put it in ${data}; exit -1; } +fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then echo "stage 0: Extracting Prosody Labeling" - bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data_dir} + bash local/prepare_dataset.sh --exp-dir ${exp_dir} --data-dir ${data} fi # convert transcription in chinese into pinyin with pypinyin or jieba+pypinyin diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 5603d3c8a..11bcf5f65 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -1,10 +1,17 @@ # LibriSpeech +## Data +| Data Subset | Duration in Seconds | +| --- | --- | +| data/manifest.train | 0.83s ~ 29.735s | +| data/manifest.dev | 1.065 ~ 35.155s | +| data/manifest.test-clean | 1.285s ~ 34.955s | + ## Deepspeech2 | Model | Params | release | Config | Test set | Loss | WER | | --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | 14.49190807 | test-clean | 0.067283 | -| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | -| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | +| DeepSpeech2 | 42.96M | 2.2.0 | conf/deepspeech2.yaml + spec_aug | test-clean | 14.49190807 | 0.067283 | +| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 | +| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 | | DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 | diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index dab8d0462..3f1a376f1 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -4,7 +4,7 @@ data: dev_manifest: data/manifest.dev-clean test_manifest: data/manifest.test-clean min_input_len: 0.0 - max_input_len: 27.0 # second + max_input_len: 30.0 # second min_output_len: 0.0 max_output_len: .inf min_output_input_ratio: 0.00 @@ -40,9 +40,12 @@ model: rnn_layer_size: 2048 use_gru: False share_rnn_weights: True + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 50 + accum_grad: 1 lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/librispeech/s0/conf/deepspeech2_online.yaml b/examples/librispeech/s0/conf/deepspeech2_online.yaml index 2e4aed40a..180a6205f 100644 --- a/examples/librispeech/s0/conf/deepspeech2_online.yaml +++ b/examples/librispeech/s0/conf/deepspeech2_online.yaml @@ -4,14 +4,14 @@ data: dev_manifest: data/manifest.dev-clean test_manifest: data/manifest.test-clean min_input_len: 0.0 - max_input_len: 27.0 # second + max_input_len: 30.0 # second min_output_len: 0.0 max_output_len: .inf min_output_input_ratio: 0.00 max_output_input_ratio: .inf collator: - batch_size: 20 + batch_size: 15 mean_std_filepath: data/mean_std.json unit_type: char vocab_filepath: data/vocab.txt @@ -42,9 +42,12 @@ model: num_fc_layers: 2 fc_layers_size_list: 512, 256 use_gru: False + blank_id: 0 + ctc_grad_norm_type: instance training: n_epoch: 50 + accum_grad: 4 lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 diff --git a/examples/librispeech/s0/local/export.sh b/examples/librispeech/s0/local/export.sh index 2e09e5f5e..a5e62c28d 100755 --- a/examples/librispeech/s0/local/export.sh +++ b/examples/librispeech/s0/local/export.sh @@ -13,13 +13,7 @@ ckpt_path_prefix=$2 jit_model_export_path=$3 model_type=$4 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi - python3 -u ${BIN_DIR}/export.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --checkpoint_path ${ckpt_path_prefix} \ diff --git a/examples/librispeech/s0/local/test.sh b/examples/librispeech/s0/local/test.sh index b5b68c599..4d00f30b8 100755 --- a/examples/librispeech/s0/local/test.sh +++ b/examples/librispeech/s0/local/test.sh @@ -8,10 +8,6 @@ fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi config_path=$1 ckpt_prefix=$2 model_type=$3 @@ -23,8 +19,7 @@ if [ $? -ne 0 ]; then fi python3 -u ${BIN_DIR}/test.py \ ---device ${device} \ ---nproc 1 \ +--nproc ${ngpu} \ --config ${config_path} \ --result_file ${ckpt_prefix}.rsl \ --checkpoint_path ${ckpt_prefix} \ diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index dcd21df34..519df7fe9 100755 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -12,28 +12,22 @@ config_path=$1 ckpt_name=$2 model_type=$3 -device=gpu -if [ ${ngpu} == 0 ];then - device=cpu -fi -echo "using ${device}..." - mkdir -p exp -seed=1024 -if [ ${seed} ]; then +# seed may break model convergence +seed=0 +if [ ${seed} != 0 ]; then export FLAGS_cudnn_deterministic=True fi python3 -u ${BIN_DIR}/train.py \ ---device ${device} \ --nproc ${ngpu} \ --config ${config_path} \ --output exp/${ckpt_name} \ --model_type ${model_type} \ --seed ${seed} -if [ ${seed} ]; then +if [ ${seed} != 0 ]; then unset FLAGS_cudnn_deterministic fi diff --git a/examples/librispeech/s0/local/tune.sh b/examples/librispeech/s0/local/tune.sh deleted file mode 100755 index c344e77e5..000000000 --- a/examples/librispeech/s0/local/tune.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash - -if [ $# != 1 ];then - echo "usage: tune ckpt_path" - exit 1 -fi - -# grid-search for hyper-parameters in language model -python3 -u ${BIN_DIR}/tune.py \ ---device 'gpu' \ ---nproc 1 \ ---config conf/deepspeech2.yaml \ ---num_batches=-1 \ ---batch_size=128 \ ---beam_size=500 \ ---num_proc_bsearch=12 \ ---num_alphas=45 \ ---num_betas=8 \ ---alpha_from=1.0 \ ---alpha_to=3.2 \ ---beta_from=0.1 \ ---beta_to=0.45 \ ---cutoff_prob=1.0 \ ---cutoff_top_n=40 \ ---checkpoint_path ${1} - -if [ $? -ne 0 ]; then - echo "Failed in tuning!" - exit 1 -fi - - -exit 0 diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index c7902a56a..af47fb9b8 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh best exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s1/cmd.sh b/examples/librispeech/s1/cmd.sh new file mode 100644 index 000000000..7b70ef5e0 --- /dev/null +++ b/examples/librispeech/s1/cmd.sh @@ -0,0 +1,89 @@ +# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ====== +# Usage: .pl [options] JOB=1: +# e.g. +# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB +# +# Options: +# --time