From 10cd65609545cbea8e25f9ec3788feda9a85f942 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 31 Aug 2021 07:41:24 +0000 Subject: [PATCH] remove notebook --- .notebook/Linear_test.ipynb | 605 -- .notebook/WarmupLR.ipynb | 339 -- .notebook/audio_feature.ipynb | 1207 ---- .notebook/compute_cmvn_loader_test.ipynb | 793 --- .notebook/dataloader.ipynb | 389 -- .../dataloader_with_tokens_tokenids.ipynb | 1204 ---- .notebook/espnet_dataloader.ipynb | 1541 ----- .notebook/hack_api_test.ipynb | 290 - .notebook/jit_infer.ipynb | 672 --- .notebook/layer_norm_test.ipynb | 229 - .notebook/mask_and_masked_fill_test.ipynb | 449 -- .notebook/position_embeding_check.ipynb | 231 - .notebook/python_test.ipynb | 1680 ------ .notebook/train_test.ipynb | 1887 ------- .notebook/u2_confermer_model_wenet.ipynb | 4608 --------------- .notebook/u2_tansformer_model_espnet.ipynb | 1672 ------ .notebook/wenet_model.ipynb | 5015 ----------------- 17 files changed, 22811 deletions(-) delete mode 100644 .notebook/Linear_test.ipynb delete mode 100644 .notebook/WarmupLR.ipynb delete mode 100644 .notebook/audio_feature.ipynb delete mode 100644 .notebook/compute_cmvn_loader_test.ipynb delete mode 100644 .notebook/dataloader.ipynb delete mode 100644 .notebook/dataloader_with_tokens_tokenids.ipynb delete mode 100644 .notebook/espnet_dataloader.ipynb delete mode 100644 .notebook/hack_api_test.ipynb delete mode 100644 .notebook/jit_infer.ipynb delete mode 100644 .notebook/layer_norm_test.ipynb delete mode 100644 .notebook/mask_and_masked_fill_test.ipynb delete mode 100644 .notebook/position_embeding_check.ipynb delete mode 100644 .notebook/python_test.ipynb delete mode 100644 .notebook/train_test.ipynb delete mode 100644 .notebook/u2_confermer_model_wenet.ipynb delete mode 100644 .notebook/u2_tansformer_model_espnet.ipynb delete mode 100644 .notebook/wenet_model.ipynb diff --git a/.notebook/Linear_test.ipynb b/.notebook/Linear_test.ipynb deleted file mode 100644 index 5c7370cf3..000000000 --- a/.notebook/Linear_test.ipynb +++ /dev/null @@ -1,605 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "academic-surname", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from paddle import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "fundamental-treasure", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "L = nn.Linear(256, 2048)\n", - "L2 = nn.Linear(2048, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "consolidated-elephant", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "moderate-noise", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n", - "Tensor(shape=[2, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[-1.54171216, -2.61531472, -1.79881978, ..., -0.31395876, 0.56513089, -0.44516513],\n", - " [-0.79492962, 1.91157901, 0.66567147, ..., 0.54825783, -1.01471853, -0.84924090],\n", - " [-1.22556651, -0.36225814, 0.65063190, ..., 0.65726501, 0.05563191, 0.09009409],\n", - " ...,\n", - " [ 0.38615900, -0.77905393, 0.99732304, ..., -1.38463700, -3.32365036, -1.31089687],\n", - " [ 0.05579993, 0.06885809, -1.66662002, ..., -0.23346378, -3.29372883, 1.30561364],\n", - " [ 1.90676069, 1.95093191, -0.28849599, ..., -0.06860496, 0.95347673, 1.00475824]],\n", - "\n", - " [[-0.91453546, 0.55298805, -1.06146812, ..., -0.86378336, 1.00454640, 1.26062179],\n", - " [ 0.10223761, 0.81301165, 2.36865163, ..., 0.16821407, 0.29240361, 1.05408621],\n", - " [-1.33196676, 1.94433689, 0.01934209, ..., 0.48036841, 0.51585966, 1.22893548],\n", - " ...,\n", - " [-0.19558455, -0.47075930, 0.90796155, ..., -1.28598249, -0.24321797, 0.17734711],\n", - " [ 0.89819717, -1.39516675, 0.17138045, ..., 2.39761519, 1.76364994, -0.52177650],\n", - " [ 0.94122332, -0.18581429, 1.36099780, ..., 0.67647684, -0.04699665, 1.51205540]]])\n", - "tensor([[[-1.5417, -2.6153, -1.7988, ..., -0.3140, 0.5651, -0.4452],\n", - " [-0.7949, 1.9116, 0.6657, ..., 0.5483, -1.0147, -0.8492],\n", - " [-1.2256, -0.3623, 0.6506, ..., 0.6573, 0.0556, 0.0901],\n", - " ...,\n", - " [ 0.3862, -0.7791, 0.9973, ..., -1.3846, -3.3237, -1.3109],\n", - " [ 0.0558, 0.0689, -1.6666, ..., -0.2335, -3.2937, 1.3056],\n", - " [ 1.9068, 1.9509, -0.2885, ..., -0.0686, 0.9535, 1.0048]],\n", - "\n", - " [[-0.9145, 0.5530, -1.0615, ..., -0.8638, 1.0045, 1.2606],\n", - " [ 0.1022, 0.8130, 2.3687, ..., 0.1682, 0.2924, 1.0541],\n", - " [-1.3320, 1.9443, 0.0193, ..., 0.4804, 0.5159, 1.2289],\n", - " ...,\n", - " [-0.1956, -0.4708, 0.9080, ..., -1.2860, -0.2432, 0.1773],\n", - " [ 0.8982, -1.3952, 0.1714, ..., 2.3976, 1.7636, -0.5218],\n", - " [ 0.9412, -0.1858, 1.3610, ..., 0.6765, -0.0470, 1.5121]]])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 51, 256)\n", - "print(x.dtype)\n", - "px = paddle.to_tensor(x, dtype='float32')\n", - "tx = torch.tensor(x, dtype=torch.float32)\n", - "print(px)\n", - "print(tx)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cooked-progressive", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "mechanical-prisoner", - "metadata": {}, - "outputs": [], - "source": [ - "data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n", - "t_norm_ff = data['norm_ff']\n", - "t_ff_out = data['ff_out']\n", - "t_ff_l_x = data['ff_l_x']\n", - "t_ff_l_a_x = data['ff_l_a_x']\n", - "t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - "t_ps = data['ps']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "indie-marriage", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "assured-zambia", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "L.set_state_dict({'weight': t_ps[0].T, 'bias': t_ps[1]})\n", - "L2.set_state_dict({'weight': t_ps[2].T, 'bias': t_ps[3]})\n", - "\n", - "ps = []\n", - "for n, p in L.named_parameters():\n", - " ps.append(p)\n", - "\n", - "for n, p in L2.state_dict().items():\n", - " ps.append(p)\n", - " \n", - "for p, tp in zip(ps, t_ps):\n", - " print(np.allclose(p.numpy(), tp.T))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "committed-jacob", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "extreme-traffic", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "optimum-milwaukee", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "viral-indian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "# data = np.load('enc_0_ff_out.npz', allow_pickle=True)\n", - "# t_norm_ff = data['norm_ff']\n", - "# t_ff_out = data['ff_out']\n", - "# t_ff_l_x = data['ff_l_x']\n", - "# t_ff_l_a_x = data['ff_l_a_x']\n", - "# t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - "# t_ps = data['ps']\n", - "TL = torch.nn.Linear(256, 2048)\n", - "TL2 = torch.nn.Linear(2048, 256)\n", - "TL.load_state_dict({'weight': torch.tensor(t_ps[0]), 'bias': torch.tensor(t_ps[1])})\n", - "TL2.load_state_dict({'weight': torch.tensor(t_ps[2]), 'bias': torch.tensor(t_ps[3])})\n", - "\n", - "# for n, p in TL.named_parameters():\n", - "# print(n, p)\n", - "# for n, p in TL2.named_parameters():\n", - "# print(n, p)\n", - "\n", - "ps = []\n", - "for n, p in TL.state_dict().items():\n", - " ps.append(p.data.numpy())\n", - " \n", - "for n, p in TL2.state_dict().items():\n", - " ps.append(p.data.numpy())\n", - " \n", - "for p, tp in zip(ps, t_ps):\n", - " print(np.allclose(p, tp))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "skilled-vietnamese", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.67277956 0.08313607 -0.62761104 ... -0.17480263 0.42718208\n", - " -0.5787626 ]\n", - " [ 0.91516656 0.5393416 1.7159258 ... 0.06144593 0.06486575\n", - " -0.03350811]\n", - " [ 0.438351 0.6227843 0.24096036 ... 1.0912522 -0.90929437\n", - " -1.012989 ]\n", - " ...\n", - " [ 0.68631977 0.14240924 0.10763275 ... -0.11513516 0.48065388\n", - " 0.04070369]\n", - " [-0.9525228 0.23197874 0.31264272 ... 0.5312439 0.18773697\n", - " -0.8450228 ]\n", - " [ 0.42024016 -0.04561988 0.54541194 ... -0.41933843 -0.00436018\n", - " -0.06663495]]\n", - "\n", - " [[-0.11638781 -0.33566502 -0.20887226 ... 0.17423287 -0.9195841\n", - " -0.8161046 ]\n", - " [-0.3469874 0.88269687 -0.11887559 ... -0.15566081 0.16357468\n", - " -0.20766167]\n", - " [-0.3847657 0.3984318 -0.06963477 ... -0.00360622 1.2360432\n", - " -0.26811332]\n", - " ...\n", - " [ 0.08230796 -0.46158582 0.54582864 ... 0.15747628 -0.44790155\n", - " 0.06020184]\n", - " [-0.8095085 0.43163058 -0.42837143 ... 0.8627463 0.90656304\n", - " 0.15847842]\n", - " [-1.485811 -0.18216592 -0.8882585 ... 0.32596245 0.7822631\n", - " -0.6460344 ]]]\n", - "[[[ 0.67278004 0.08313602 -0.6276114 ... -0.17480245 0.42718196\n", - " -0.5787625 ]\n", - " [ 0.91516703 0.5393413 1.7159253 ... 0.06144581 0.06486579\n", - " -0.03350812]\n", - " [ 0.43835106 0.62278455 0.24096027 ... 1.0912521 -0.9092943\n", - " -1.0129892 ]\n", - " ...\n", - " [ 0.6863195 0.14240888 0.10763284 ... -0.11513527 0.48065376\n", - " 0.04070365]\n", - " [-0.9525231 0.23197863 0.31264275 ... 0.53124386 0.18773702\n", - " -0.84502304]\n", - " [ 0.42024007 -0.04561983 0.545412 ... -0.41933888 -0.00436005\n", - " -0.066635 ]]\n", - "\n", - " [[-0.11638767 -0.33566508 -0.20887226 ... 0.17423296 -0.9195838\n", - " -0.8161046 ]\n", - " [-0.34698725 0.88269705 -0.11887549 ... -0.15566081 0.16357464\n", - " -0.20766166]\n", - " [-0.3847657 0.3984319 -0.06963488 ... -0.00360619 1.2360426\n", - " -0.26811326]\n", - " ...\n", - " [ 0.08230786 -0.4615857 0.5458287 ... 0.15747619 -0.44790167\n", - " 0.06020182]\n", - " [-0.8095083 0.4316307 -0.42837155 ... 0.862746 0.9065631\n", - " 0.15847899]\n", - " [-1.485811 -0.18216613 -0.8882584 ... 0.32596254 0.7822631\n", - " -0.6460344 ]]]\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "y = L(px)\n", - "print(y.numpy())\n", - "\n", - "ty = TL(tx)\n", - "print(ty.data.numpy())\n", - "print(np.allclose(px.numpy(), tx.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "incorrect-allah", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "prostate-cameroon", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "governmental-surge", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.04476918 0.554463 -0.3027508 ... -0.49600336 0.3751858\n", - " 0.8254095 ]\n", - " [ 0.95594174 -0.29528382 -1.2899452 ... 0.43718258 0.05584608\n", - " -0.06974669]]\n", - "[[ 0.04476918 0.5544631 -0.3027507 ... -0.49600336 0.37518573\n", - " 0.8254096 ]\n", - " [ 0.95594174 -0.29528376 -1.2899454 ... 0.4371827 0.05584623\n", - " -0.0697467 ]]\n", - "True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 256)\n", - "px = paddle.to_tensor(x, dtype='float32')\n", - "tx = torch.tensor(x, dtype=torch.float32)\n", - "y = L(px)\n", - "print(y.numpy())\n", - "ty = TL(tx)\n", - "print(ty.data.numpy())\n", - "print(np.allclose(px.numpy(), tx.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy()))\n", - "print(np.allclose(y.numpy(), ty.detach().numpy(), atol=1e-5))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "confidential-jacket", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "improved-civilization", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "5e7e7c9fde8350084abf1898cf52651cfc84b17a\n" - ] - } - ], - "source": [ - "print(paddle.version.commit)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "d1e2d3b4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " 'commit',\n", - " 'full_version',\n", - " 'istaged',\n", - " 'major',\n", - " 'minor',\n", - " 'mkl',\n", - " 'patch',\n", - " 'rc',\n", - " 'show',\n", - " 'with_mkl']" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(paddle.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "c880c719", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.1.0\n" - ] - } - ], - "source": [ - "print(paddle.version.full_version)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "f26977bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "commit: 5e7e7c9fde8350084abf1898cf52651cfc84b17a\n", - "None\n" - ] - } - ], - "source": [ - "print(paddle.version.show())" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "04ad47f6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.6.0\n" - ] - } - ], - "source": [ - "print(torch.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e1e03830", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " '__version__',\n", - " 'cuda',\n", - " 'debug',\n", - " 'git_version',\n", - " 'hip']" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(torch.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "4ad0389b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'b31f58de6fa8bbda5353b3c77d9be4914399724d'" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.version.git_version" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "7870ea10", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'10.2'" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.version.cuda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "db8ee5a7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6321ec2a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/WarmupLR.ipynb b/.notebook/WarmupLR.ipynb deleted file mode 100644 index 21abf9cbe..000000000 --- a/.notebook/WarmupLR.ipynb +++ /dev/null @@ -1,339 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "d6a0e098", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Union\n", - "\n", - "import torch\n", - "from torch.optim.lr_scheduler import _LRScheduler\n", - "\n", - "from typeguard import check_argument_types\n", - "\n", - "\n", - "class WarmupLR(_LRScheduler):\n", - " \"\"\"The WarmupLR scheduler\n", - " This scheduler is almost same as NoamLR Scheduler except for following\n", - " difference:\n", - " NoamLR:\n", - " lr = optimizer.lr * model_size ** -0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " WarmupLR:\n", - " lr = optimizer.lr * warmup_step ** 0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " optimizer: torch.optim.Optimizer,\n", - " warmup_steps: Union[int, float] = 25000,\n", - " last_epoch: int = -1,\n", - " ):\n", - " assert check_argument_types()\n", - " self.warmup_steps = warmup_steps\n", - "\n", - " # __init__() must be invoked before setting field\n", - " # because step() is also invoked in __init__()\n", - " super().__init__(optimizer, last_epoch)\n", - "\n", - " def __repr__(self):\n", - " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", - "\n", - " def get_lr(self):\n", - " step_num = self.last_epoch + 1\n", - " return [\n", - " lr\n", - " * self.warmup_steps ** 0.5\n", - " * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n", - " for lr in self.base_lrs\n", - " ]\n", - "\n", - " def set_step(self, step: int):\n", - " self.last_epoch = step" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0d496677", - "metadata": {}, - "outputs": [], - "source": [ - "import torch.optim as optim\n", - "model = torch.nn.Linear(10, 200)\n", - "optimizer = optim.Adam(model.parameters())\n", - "scheduler = WarmupLR(optimizer, warmup_steps=25000)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e3e3f3dc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0 0.0 -1\n" - ] - } - ], - "source": [ - "infos = {}\n", - "start_epoch = infos.get('epoch', -1) + 1\n", - "cv_loss = infos.get('cv_loss', 0.0)\n", - "step = infos.get('step', -1)\n", - "print(start_epoch, cv_loss, step)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "dc3d550c", - "metadata": {}, - "outputs": [], - "source": [ - "scheduler.set_step(step)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "e527634e", - "metadata": {}, - "outputs": [], - "source": [ - "lrs=[]\n", - "for i in range(100000):\n", - " scheduler.step()\n", - " lrs.append(scheduler.get_lr())" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f1452db9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting matplotlib\n", - " Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n", - "\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n", - "\u001b[?25hCollecting kiwisolver>=1.0.1\n", - " Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n", - "\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n", - "Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n", - "Collecting cycler>=0.10\n", - " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", - "Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n", - "Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", - "Installing collected packages: kiwisolver, cycler, matplotlib\n", - "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" - ] - } - ], - "source": [ - "!pip install matplotlib\n", - "import matplotlib.pyplot as plt\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "0f36d04f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wzuUATwG/bquTc+4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3N+3l3Y2VvLiiAoDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHShYwUAr1IBQBGJIgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz7cQmDt4lj7zd7VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "xs = list(range(100000))\n", - "plt.plot(xs, lrs)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "4f4e282c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "from typing import Union\n", - "\n", - "from paddle.optimizer.lr import LRScheduler\n", - "from typeguard import check_argument_types\n", - "\n", - "class WarmupLR(LRScheduler):\n", - " \"\"\"The WarmupLR scheduler\n", - " This scheduler is almost same as NoamLR Scheduler except for following\n", - " difference:\n", - " NoamLR:\n", - " lr = optimizer.lr * model_size ** -0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " WarmupLR:\n", - " lr = optimizer.lr * warmup_step ** 0.5\n", - " * min(step ** -0.5, step * warmup_step ** -1.5)\n", - " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " warmup_steps: Union[int, float]=25000,\n", - " learning_rate=1.0,\n", - " last_epoch=-1,\n", - " verbose=False):\n", - " assert check_argument_types()\n", - " self.warmup_steps = warmup_steps\n", - " super().__init__(learning_rate, last_epoch, verbose)\n", - "\n", - " def __repr__(self):\n", - " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", - "\n", - " def get_lr(self):\n", - " step_num = self.last_epoch + 1\n", - " return self.base_lr * self.warmup_steps**0.5 * min(\n", - " step_num**-0.5, step_num * self.warmup_steps**-1.5)\n", - "\n", - " def set_step(self, step: int):\n", - " self.step(step)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "8c40b202", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-1\n" - ] - } - ], - "source": [ - "sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n", - "print(step)\n", - "#sc.set_step(step)\n", - "sc.set_step(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "ecbc7e37", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA/gj8JumOjnnnnLO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0gQVGKf87ghAyvrak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBRt6GQBwLEqACgi4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfpN39cQkoMo7/FwIXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "lrs=[]\n", - "for i in range(100000):\n", - " sc.step()\n", - " lrs.append(sc.get_lr())\n", - "xs = list(range(100000))\n", - "plt.plot(xs, lrs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e613fe16", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f0fd9f40", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/audio_feature.ipynb b/.notebook/audio_feature.ipynb deleted file mode 100644 index 04b4a3924..000000000 --- a/.notebook/audio_feature.ipynb +++ /dev/null @@ -1,1207 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 94, - "id": "matched-camera", - "metadata": {}, - "outputs": [], - "source": [ - "from nnAudio import Spectrogram\n", - "from scipy.io import wavfile\n", - "import torch\n", - "import soundfile as sf\n", - "import numpy as np" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "quarterly-solution", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[43 75 69 ... 7 6 3]\n", - "[43 75 69 ... 7 6 3]\n", - "[43 75 69 ... 7 6 3]\n" - ] - } - ], - "source": [ - "import scipy.io.wavfile as wav\n", - "\n", - "rate,sig = wav.read('./BAC009S0764W0124.wav')\n", - "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", - "sample, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sig)\n", - "print(song)\n", - "print(sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "id": "middle-salem", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "[43 75 69 ... 7 6 3]\n", - "(83792,)\n", - "int16\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2733 seconds\n", - "tensor([[[[-4.0940e+03, 1.2600e+04],\n", - " [ 8.5108e+03, -5.4930e+03],\n", - " [-3.3631e+03, -1.7904e+03],\n", - " ...,\n", - " [ 8.2279e+03, -9.3340e+03],\n", - " [-3.1990e+03, 2.0969e+03],\n", - " [-1.2669e+03, 4.4488e+03]],\n", - "\n", - " [[ 3.4886e+03, -9.9620e+03],\n", - " [-4.5364e+03, 4.1907e+02],\n", - " [ 2.5074e+03, 7.1339e+03],\n", - " ...,\n", - " [-5.4819e+03, 3.9258e+01],\n", - " [ 4.7221e+03, 6.5887e+01],\n", - " [ 9.6492e+02, -3.4386e+03]],\n", - "\n", - " [[-3.4947e+03, 9.2981e+03],\n", - " [-7.5164e+03, 8.1856e+02],\n", - " [-5.3766e+03, -9.0889e+03],\n", - " ...,\n", - " [ 1.4317e+03, 5.7447e+03],\n", - " [-3.1178e+03, 3.0740e+03],\n", - " [-3.4351e+03, 5.6900e+02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.7112e+01, -4.5737e+00],\n", - " [-9.6295e+00, 3.5554e+01],\n", - " [ 1.8527e+00, -1.0491e+01],\n", - " ...,\n", - " [-1.1157e+01, 3.4423e+00],\n", - " [ 3.1193e+00, -4.4388e+00],\n", - " [-8.8242e+00, 8.0324e+00]],\n", - "\n", - " [[-6.5080e+01, 2.9543e+00],\n", - " [ 3.9992e+01, -1.3836e+01],\n", - " [-9.2803e+00, 1.0318e+01],\n", - " ...,\n", - " [ 4.2928e+00, 9.2397e+00],\n", - " [ 3.6642e+00, 9.4680e+00],\n", - " [ 4.8932e+00, -2.5199e+01]],\n", - "\n", - " [[ 4.7264e+01, -1.0721e+00],\n", - " [-6.0516e+00, -1.4589e+01],\n", - " [ 1.3127e+01, 1.4995e+00],\n", - " ...,\n", - " [ 1.7333e+01, -1.4380e+01],\n", - " [-3.6046e+00, -6.1019e+00],\n", - " [ 1.3321e+01, 2.3184e+01]]]])\n" - ] - } - ], - "source": [ - "sr, song = wavfile.read('./BAC009S0764W0124.wav') # Loading your audio\n", - "print(sr)\n", - "print(song)\n", - "print(song.shape)\n", - "print(song.dtype)\n", - "x = song\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", - " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", - " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", - "\n", - "spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "print(spec)" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "finished-sterling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "[43 75 69 ... 7 6 3]\n", - "(83792,)\n", - "int16\n", - "True\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.2001 seconds\n", - "torch.Size([1, 1025, 164, 2])\n", - "tensor([[[[-4.0940e+03, 1.2600e+04],\n", - " [ 8.5108e+03, -5.4930e+03],\n", - " [-3.3631e+03, -1.7904e+03],\n", - " ...,\n", - " [ 8.2279e+03, -9.3340e+03],\n", - " [-3.1990e+03, 2.0969e+03],\n", - " [-1.2669e+03, 4.4488e+03]],\n", - "\n", - " [[ 3.4886e+03, -9.9620e+03],\n", - " [-4.5364e+03, 4.1907e+02],\n", - " [ 2.5074e+03, 7.1339e+03],\n", - " ...,\n", - " [-5.4819e+03, 3.9258e+01],\n", - " [ 4.7221e+03, 6.5887e+01],\n", - " [ 9.6492e+02, -3.4386e+03]],\n", - "\n", - " [[-3.4947e+03, 9.2981e+03],\n", - " [-7.5164e+03, 8.1856e+02],\n", - " [-5.3766e+03, -9.0889e+03],\n", - " ...,\n", - " [ 1.4317e+03, 5.7447e+03],\n", - " [-3.1178e+03, 3.0740e+03],\n", - " [-3.4351e+03, 5.6900e+02]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.7112e+01, -4.5737e+00],\n", - " [-9.6295e+00, 3.5554e+01],\n", - " [ 1.8527e+00, -1.0491e+01],\n", - " ...,\n", - " [-1.1157e+01, 3.4423e+00],\n", - " [ 3.1193e+00, -4.4388e+00],\n", - " [-8.8242e+00, 8.0324e+00]],\n", - "\n", - " [[-6.5080e+01, 2.9543e+00],\n", - " [ 3.9992e+01, -1.3836e+01],\n", - " [-9.2803e+00, 1.0318e+01],\n", - " ...,\n", - " [ 4.2928e+00, 9.2397e+00],\n", - " [ 3.6642e+00, 9.4680e+00],\n", - " [ 4.8932e+00, -2.5199e+01]],\n", - "\n", - " [[ 4.7264e+01, -1.0721e+00],\n", - " [-6.0516e+00, -1.4589e+01],\n", - " [ 1.3127e+01, 1.4995e+00],\n", - " ...,\n", - " [ 1.7333e+01, -1.4380e+01],\n", - " [-3.6046e+00, -6.1019e+00],\n", - " [ 1.3321e+01, 2.3184e+01]]]])\n", - "True\n" - ] - } - ], - "source": [ - "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sr)\n", - "print(wav)\n", - "print(wav.shape)\n", - "print(wav.dtype)\n", - "print(np.allclose(wav, song))\n", - "\n", - "x = wav\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = Spectrogram.STFT(n_fft=2048, freq_bins=None, hop_length=512,\n", - " window='hann', freq_scale='linear', center=True, pad_mode='reflect',\n", - " fmin=50,fmax=8000, sr=sr) # Initializing the model\n", - "\n", - "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "print(wav_spec.shape)\n", - "print(wav_spec)\n", - "print(np.allclose(wav_spec, spec))" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "running-technology", - "metadata": {}, - "outputs": [], - "source": [ - "import decimal\n", - "\n", - "import numpy\n", - "import math\n", - "import logging\n", - "def round_half_up(number):\n", - " return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))\n", - "\n", - "\n", - "def rolling_window(a, window, step=1):\n", - " # http://ellisvalentiner.com/post/2017-03-21-np-strides-trick\n", - " shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n", - " strides = a.strides + (a.strides[-1],)\n", - " return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]\n", - "\n", - "\n", - "def framesig(sig, frame_len, frame_step, dither=1.0, preemph=0.97, remove_dc_offset=True, wintype='hamming', stride_trick=True):\n", - " \"\"\"Frame a signal into overlapping frames.\n", - "\n", - " :param sig: the audio signal to frame.\n", - " :param frame_len: length of each frame measured in samples.\n", - " :param frame_step: number of samples after the start of the previous frame that the next frame should begin.\n", - " :param winfunc: the analysis window to apply to each frame. By default no window is applied.\n", - " :param stride_trick: use stride trick to compute the rolling window and window multiplication faster\n", - " :returns: an array of frames. Size is NUMFRAMES by frame_len.\n", - " \"\"\"\n", - " slen = len(sig)\n", - " frame_len = int(round_half_up(frame_len))\n", - " frame_step = int(round_half_up(frame_step))\n", - " if slen <= frame_len:\n", - " numframes = 1\n", - " else:\n", - " numframes = 1 + (( slen - frame_len) // frame_step)\n", - "\n", - " # check kaldi/src/feat/feature-window.h\n", - " padsignal = sig[:(numframes-1)*frame_step+frame_len]\n", - " if wintype is 'povey':\n", - " win = numpy.empty(frame_len)\n", - " for i in range(frame_len):\n", - " win[i] = (0.5-0.5*numpy.cos(2*numpy.pi/(frame_len-1)*i))**0.85 \n", - " else: # the hamming window\n", - " win = numpy.hamming(frame_len)\n", - " \n", - " if stride_trick:\n", - " frames = rolling_window(padsignal, window=frame_len, step=frame_step)\n", - " else:\n", - " indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(\n", - " numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T\n", - " indices = numpy.array(indices, dtype=numpy.int32)\n", - " frames = padsignal[indices]\n", - " win = numpy.tile(win, (numframes, 1))\n", - " \n", - " frames = frames.astype(numpy.float32)\n", - " raw_frames = numpy.zeros(frames.shape)\n", - " for frm in range(frames.shape[0]):\n", - " raw_frames[frm,:] = frames[frm,:]\n", - " frames[frm,:] = do_dither(frames[frm,:], dither) # dither\n", - " frames[frm,:] = do_remove_dc_offset(frames[frm,:]) # remove dc offset\n", - " # raw_frames[frm,:] = frames[frm,:]\n", - " frames[frm,:] = do_preemphasis(frames[frm,:], preemph) # preemphasize\n", - "\n", - " return frames * win, raw_frames\n", - "\n", - "\n", - "def magspec(frames, NFFT):\n", - " \"\"\"Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", - "\n", - " :param frames: the array of frames. Each row is a frame.\n", - " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", - " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.\n", - " \"\"\"\n", - " if numpy.shape(frames)[1] > NFFT:\n", - " logging.warn(\n", - " 'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',\n", - " numpy.shape(frames)[1], NFFT)\n", - " complex_spec = numpy.fft.rfft(frames, NFFT)\n", - " return numpy.absolute(complex_spec)\n", - "\n", - "\n", - "def powspec(frames, NFFT):\n", - " \"\"\"Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).\n", - "\n", - " :param frames: the array of frames. Each row is a frame.\n", - " :param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.\n", - " :returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.\n", - " \"\"\"\n", - " return numpy.square(magspec(frames, NFFT))\n", - "\n", - "\n", - "def do_dither(signal, dither_value=1.0):\n", - " signal += numpy.random.normal(size=signal.shape) * dither_value\n", - " return signal\n", - " \n", - "def do_remove_dc_offset(signal):\n", - " signal -= numpy.mean(signal)\n", - " return signal\n", - "\n", - "def do_preemphasis(signal, coeff=0.97):\n", - " \"\"\"perform preemphasis on the input signal.\n", - "\n", - " :param signal: The signal to filter.\n", - " :param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.\n", - " :returns: the filtered signal.\n", - " \"\"\"\n", - " return numpy.append((1-coeff)*signal[0], signal[1:] - coeff * signal[:-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "ignored-retreat", - "metadata": {}, - "outputs": [], - "source": [ - "def fbank(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", - " wintype='hamming'):\n", - " highfreq= highfreq or samplerate/2\n", - " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", - " spec = magspec(frames, nfft) # nearly the same until this part\n", - " rspec = magspec(raw_frames, nfft)\n", - " return spec, rspec\n", - "\n", - "\n", - "\n", - "def frames(signal,samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40,nfft=512,lowfreq=0,highfreq=None,dither=1.0,remove_dc_offset=True, preemph=0.97, \n", - " wintype='hamming'):\n", - " highfreq= highfreq or samplerate/2\n", - " frames, raw_frames = framesig(signal, winlen*samplerate, winstep*samplerate, dither, preemph, remove_dc_offset, wintype)\n", - " return raw_frames" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "federal-teacher", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import torch\n", - "import torch.nn as nn\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "import scipy # used only in CFP\n", - "\n", - "import numpy as np\n", - "from time import time\n", - "\n", - "def pad_center(data, size, axis=-1, **kwargs):\n", - "\n", - " kwargs.setdefault('mode', 'constant')\n", - "\n", - " n = data.shape[axis]\n", - "\n", - " lpad = int((size - n) // 2)\n", - "\n", - " lengths = [(0, 0)] * data.ndim\n", - " lengths[axis] = (lpad, int(size - n - lpad))\n", - "\n", - " if lpad < 0:\n", - " raise ParameterError(('Target size ({:d}) must be '\n", - " 'at least input size ({:d})').format(size, n))\n", - "\n", - " return np.pad(data, lengths, **kwargs)\n", - "\n", - "\n", - "\n", - "sz_float = 4 # size of a float\n", - "epsilon = 10e-8 # fudge factor for normalization\n", - "\n", - "def create_fourier_kernels(n_fft, win_length=None, freq_bins=None, fmin=50,fmax=6000, sr=44100,\n", - " freq_scale='linear', window='hann', verbose=True):\n", - "\n", - " if freq_bins==None: freq_bins = n_fft//2+1\n", - " if win_length==None: win_length = n_fft\n", - "\n", - " s = np.arange(0, n_fft, 1.)\n", - " wsin = np.empty((freq_bins,1,n_fft))\n", - " wcos = np.empty((freq_bins,1,n_fft))\n", - " start_freq = fmin\n", - " end_freq = fmax\n", - " bins2freq = []\n", - " binslist = []\n", - "\n", - " # num_cycles = start_freq*d/44000.\n", - " # scaling_ind = np.log(end_freq/start_freq)/k\n", - "\n", - " # Choosing window shape\n", - "\n", - " #window_mask = get_window(window, int(win_length), fftbins=True)\n", - " window_mask = np.hamming(int(win_length))\n", - " window_mask = pad_center(window_mask, n_fft)\n", - "\n", - " if freq_scale == 'linear':\n", - " if verbose==True:\n", - " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", - " f\"get a valid freq range\")\n", - " \n", - " start_bin = start_freq*n_fft/sr\n", - " scaling_ind = (end_freq-start_freq)*(n_fft/sr)/freq_bins\n", - "\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " # print(\"linear freq = {}\".format((k*scaling_ind+start_bin)*sr/n_fft))\n", - " bins2freq.append((k*scaling_ind+start_bin)*sr/n_fft)\n", - " binslist.append((k*scaling_ind+start_bin))\n", - " wsin[k,0,:] = np.sin(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*(k*scaling_ind+start_bin)*s/n_fft)\n", - "\n", - " elif freq_scale == 'log':\n", - " if verbose==True:\n", - " print(f\"sampling rate = {sr}. Please make sure the sampling rate is correct in order to\"\n", - " f\"get a valid freq range\")\n", - " start_bin = start_freq*n_fft/sr\n", - " scaling_ind = np.log(end_freq/start_freq)/freq_bins\n", - "\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " # print(\"log freq = {}\".format(np.exp(k*scaling_ind)*start_bin*sr/n_fft))\n", - " bins2freq.append(np.exp(k*scaling_ind)*start_bin*sr/n_fft)\n", - " binslist.append((np.exp(k*scaling_ind)*start_bin))\n", - " wsin[k,0,:] = np.sin(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*(np.exp(k*scaling_ind)*start_bin)*s/n_fft)\n", - "\n", - " elif freq_scale == 'no':\n", - " for k in range(freq_bins): # Only half of the bins contain useful info\n", - " bins2freq.append(k*sr/n_fft)\n", - " binslist.append(k)\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", - " else:\n", - " print(\"Please select the correct frequency scale, 'linear' or 'log'\")\n", - " return wsin.astype(np.float32),wcos.astype(np.float32), bins2freq, binslist, window_mask.astype(np.float32)\n", - "\n", - "\n", - "\n", - "def broadcast_dim(x):\n", - " \"\"\"\n", - " Auto broadcast input so that it can fits into a Conv1d\n", - " \"\"\"\n", - "\n", - " if x.dim() == 2:\n", - " x = x[:, None, :]\n", - " elif x.dim() == 1:\n", - " # If nn.DataParallel is used, this broadcast doesn't work\n", - " x = x[None, None, :]\n", - " elif x.dim() == 3:\n", - " pass\n", - " else:\n", - " raise ValueError(\"Only support input with shape = (batch, len) or shape = (len)\")\n", - " return x\n", - "\n", - "\n", - "\n", - "### --------------------------- Spectrogram Classes ---------------------------###\n", - "class STFT(torch.nn.Module):\n", - "\n", - " def __init__(self, n_fft=2048, win_length=None, freq_bins=None, hop_length=None, window='hann',\n", - " freq_scale='no', center=True, pad_mode='reflect', iSTFT=False,\n", - " fmin=50, fmax=6000, sr=22050, trainable=False,\n", - " output_format=\"Complex\", verbose=True):\n", - "\n", - " super().__init__()\n", - "\n", - " # Trying to make the default setting same as librosa\n", - " if win_length==None: win_length = n_fft\n", - " if hop_length==None: hop_length = int(win_length // 4)\n", - "\n", - " self.output_format = output_format\n", - " self.trainable = trainable\n", - " self.stride = hop_length\n", - " self.center = center\n", - " self.pad_mode = pad_mode\n", - " self.n_fft = n_fft\n", - " self.freq_bins = freq_bins\n", - " self.trainable = trainable\n", - " self.pad_amount = self.n_fft // 2\n", - " self.window = window\n", - " self.win_length = win_length\n", - " self.iSTFT = iSTFT\n", - " self.trainable = trainable\n", - " start = time()\n", - "\n", - "\n", - "\n", - " # Create filter windows for stft\n", - " kernel_sin, kernel_cos, self.bins2freq, self.bin_list, window_mask = create_fourier_kernels(n_fft,\n", - " win_length=win_length,\n", - " freq_bins=freq_bins,\n", - " window=window,\n", - " freq_scale=freq_scale,\n", - " fmin=fmin,\n", - " fmax=fmax,\n", - " sr=sr,\n", - " verbose=verbose)\n", - "\n", - "\n", - " kernel_sin = torch.tensor(kernel_sin, dtype=torch.float)\n", - " kernel_cos = torch.tensor(kernel_cos, dtype=torch.float)\n", - " \n", - " # In this way, the inverse kernel and the forward kernel do not share the same memory...\n", - " kernel_sin_inv = torch.cat((kernel_sin, -kernel_sin[1:-1].flip(0)), 0)\n", - " kernel_cos_inv = torch.cat((kernel_cos, kernel_cos[1:-1].flip(0)), 0)\n", - " \n", - " if iSTFT:\n", - " self.register_buffer('kernel_sin_inv', kernel_sin_inv.unsqueeze(-1))\n", - " self.register_buffer('kernel_cos_inv', kernel_cos_inv.unsqueeze(-1))\n", - "\n", - " # Applying window functions to the Fourier kernels\n", - " if window:\n", - " window_mask = torch.tensor(window_mask)\n", - " wsin = kernel_sin * window_mask\n", - " wcos = kernel_cos * window_mask\n", - " else:\n", - " wsin = kernel_sin\n", - " wcos = kernel_cos\n", - " \n", - " if self.trainable==False:\n", - " self.register_buffer('wsin', wsin)\n", - " self.register_buffer('wcos', wcos) \n", - " \n", - " if self.trainable==True:\n", - " wsin = torch.nn.Parameter(wsin, requires_grad=self.trainable)\n", - " wcos = torch.nn.Parameter(wcos, requires_grad=self.trainable) \n", - " self.register_parameter('wsin', wsin)\n", - " self.register_parameter('wcos', wcos) \n", - " \n", - " # Prepare the shape of window mask so that it can be used later in inverse\n", - " # self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))\n", - " \n", - " if verbose==True:\n", - " print(\"STFT kernels created, time used = {:.4f} seconds\".format(time()-start))\n", - " else:\n", - " pass\n", - "\n", - " def forward(self, x, output_format=None):\n", - " \"\"\"\n", - " Convert a batch of waveforms to spectrograms.\n", - " \n", - " Parameters\n", - " ----------\n", - " x : torch tensor\n", - " Input signal should be in either of the following shapes.\\n\n", - " 1. ``(len_audio)``\\n\n", - " 2. ``(num_audio, len_audio)``\\n\n", - " 3. ``(num_audio, 1, len_audio)``\n", - " It will be automatically broadcast to the right shape\n", - " \n", - " output_format : str\n", - " Control the type of spectrogram to be return. Can be either ``Magnitude`` or ``Complex`` or ``Phase``.\n", - " Default value is ``Complex``. \n", - " \n", - " \"\"\"\n", - " output_format = output_format or self.output_format\n", - " self.num_samples = x.shape[-1]\n", - " \n", - " x = broadcast_dim(x)\n", - " if self.center:\n", - " if self.pad_mode == 'constant':\n", - " padding = nn.ConstantPad1d(self.pad_amount, 0)\n", - "\n", - " elif self.pad_mode == 'reflect':\n", - " if self.num_samples < self.pad_amount:\n", - " raise AssertionError(\"Signal length shorter than reflect padding length (n_fft // 2).\")\n", - " padding = nn.ReflectionPad1d(self.pad_amount)\n", - "\n", - " x = padding(x)\n", - " spec_imag = conv1d(x, self.wsin, stride=self.stride)\n", - " spec_real = conv1d(x, self.wcos, stride=self.stride) # Doing STFT by using conv1d\n", - "\n", - " # remove redundant parts\n", - " spec_real = spec_real[:, :self.freq_bins, :]\n", - " spec_imag = spec_imag[:, :self.freq_bins, :]\n", - "\n", - " if output_format=='Magnitude':\n", - " spec = spec_real.pow(2) + spec_imag.pow(2)\n", - " if self.trainable==True:\n", - " return torch.sqrt(spec+1e-8) # prevent Nan gradient when sqrt(0) due to output=0\n", - " else:\n", - " return torch.sqrt(spec)\n", - "\n", - " elif output_format=='Complex':\n", - " return torch.stack((spec_real,-spec_imag), -1) # Remember the minus sign for imaginary part\n", - "\n", - " elif output_format=='Phase':\n", - " return torch.atan2(-spec_imag+0.0,spec_real) # +0.0 removes -0.0 elements, which leads to error in calculating phase\n", - "\n", - " def inverse(self, X, onesided=True, length=None, refresh_win=True):\n", - " \"\"\"\n", - " This function is same as the :func:`~nnAudio.Spectrogram.iSTFT` class, \n", - " which is to convert spectrograms back to waveforms. \n", - " It only works for the complex value spectrograms. If you have the magnitude spectrograms,\n", - " please use :func:`~nnAudio.Spectrogram.Griffin_Lim`. \n", - " \n", - " Parameters\n", - " ----------\n", - " onesided : bool\n", - " If your spectrograms only have ``n_fft//2+1`` frequency bins, please use ``onesided=True``,\n", - " else use ``onesided=False``\n", - "\n", - " length : int\n", - " To make sure the inverse STFT has the same output length of the original waveform, please\n", - " set `length` as your intended waveform length. By default, ``length=None``,\n", - " which will remove ``n_fft//2`` samples from the start and the end of the output.\n", - " \n", - " refresh_win : bool\n", - " Recalculating the window sum square. If you have an input with fixed number of timesteps,\n", - " you can increase the speed by setting ``refresh_win=False``. Else please keep ``refresh_win=True``\n", - " \n", - " \n", - " \"\"\"\n", - " if (hasattr(self, 'kernel_sin_inv') != True) or (hasattr(self, 'kernel_cos_inv') != True):\n", - " raise NameError(\"Please activate the iSTFT module by setting `iSTFT=True` if you want to use `inverse`\") \n", - " \n", - " assert X.dim()==4 , \"Inverse iSTFT only works for complex number,\" \\\n", - " \"make sure our tensor is in the shape of (batch, freq_bins, timesteps, 2).\"\\\n", - " \"\\nIf you have a magnitude spectrogram, please consider using Griffin-Lim.\"\n", - " if onesided:\n", - " X = extend_fbins(X) # extend freq\n", - "\n", - " \n", - " X_real, X_imag = X[:, :, :, 0], X[:, :, :, 1]\n", - "\n", - " # broadcast dimensions to support 2D convolution\n", - " X_real_bc = X_real.unsqueeze(1)\n", - " X_imag_bc = X_imag.unsqueeze(1)\n", - " a1 = conv2d(X_real_bc, self.kernel_cos_inv, stride=(1,1))\n", - " b2 = conv2d(X_imag_bc, self.kernel_sin_inv, stride=(1,1))\n", - " \n", - " # compute real and imag part. signal lies in the real part\n", - " real = a1 - b2\n", - " real = real.squeeze(-2)*self.window_mask\n", - "\n", - " # Normalize the amplitude with n_fft\n", - " real /= (self.n_fft)\n", - "\n", - " # Overlap and Add algorithm to connect all the frames\n", - " real = overlap_add(real, self.stride)\n", - " \n", - " # Prepare the window sumsqure for division\n", - " # Only need to create this window once to save time\n", - " # Unless the input spectrograms have different time steps\n", - " if hasattr(self, 'w_sum')==False or refresh_win==True:\n", - " self.w_sum = torch_window_sumsquare(self.window_mask.flatten(), X.shape[2], self.stride, self.n_fft).flatten()\n", - " self.nonzero_indices = (self.w_sum>1e-10) \n", - " else:\n", - " pass\n", - " real[:, self.nonzero_indices] = real[:,self.nonzero_indices].div(self.w_sum[self.nonzero_indices])\n", - " # Remove padding\n", - " if length is None: \n", - " if self.center:\n", - " real = real[:, self.pad_amount:-self.pad_amount]\n", - "\n", - " else:\n", - " if self.center:\n", - " real = real[:, self.pad_amount:self.pad_amount + length] \n", - " else:\n", - " real = real[:, :length] \n", - " \n", - " return real\n", - " \n", - " def extra_repr(self) -> str:\n", - " return 'n_fft={}, Fourier Kernel size={}, iSTFT={}, trainable={}'.format(\n", - " self.n_fft, (*self.wsin.shape,), self.iSTFT, self.trainable\n", - " ) " - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "id": "unusual-baker", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "16000\n", - "(83792,)\n", - "sampling rate = 16000. Please make sure the sampling rate is correct in order toget a valid freq range\n", - "STFT kernels created, time used = 0.0153 seconds\n", - "torch.Size([521, 257])\n", - "(522, 257)\n", - "[[5.84560000e+04 2.55260664e+04 9.83611035e+03 ... 7.80710554e+00\n", - " 2.32206573e+01 1.90274487e+01]\n", - " [1.35420000e+04 3.47535000e+04 1.51204707e+04 ... 1.69094101e+02\n", - " 1.80534729e+02 1.84179596e+02]\n", - " [3.47560000e+04 2.83094609e+04 8.20204883e+03 ... 1.02080307e+02\n", - " 1.21321175e+02 1.08345497e+02]\n", - " ...\n", - " [9.36700000e+03 2.86213008e+04 1.41182402e+04 ... 1.19344498e+02\n", - " 1.25670158e+02 1.20691467e+02]\n", - " [2.87510000e+04 2.04348242e+04 8.76390625e+03 ... 9.74485092e+01\n", - " 9.01831894e+01 9.84055099e+01]\n", - " [4.45240000e+04 8.93593262e+03 4.39246826e+03 ... 6.16300154e+00\n", - " 8.94473553e+00 9.61348629e+00]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]\n", - " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", - " 6.36197377e+01 4.40000000e+01]]\n" - ] - } - ], - "source": [ - "wav, sr = sf.read('./BAC009S0764W0124.wav', dtype='int16')\n", - "print(sr)\n", - "print(wav.shape)\n", - "\n", - "x = wav\n", - "x = torch.tensor(x).float() # casting the array into a PyTorch Tensor\n", - "\n", - "spec_layer = STFT(n_fft=512, win_length=400, hop_length=160,\n", - " window='', freq_scale='linear', center=False, pad_mode='constant',\n", - " fmin=0, fmax=8000, sr=sr, output_format='Magnitude')\n", - "wav_spec = spec_layer(x) # Feed-forward your waveform to get the spectrogram\n", - "wav_spec = wav_spec[0].T\n", - "print(wav_spec.shape)\n", - "\n", - "\n", - "spec, rspec = fbank(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "print(spec.shape)\n", - "\n", - "print(wav_spec.numpy())\n", - "print(rspec)\n", - "# print(spec)\n", - "\n", - "# spec, rspec = fbank(wav, samplerate=16000,winlen=0.032,winstep=0.01,\n", - "# nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - "# dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - "# wintype='hamming')\n", - "# print(rspec)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "white-istanbul", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 129, - "id": "modern-rescue", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0. 0.11697778 0.41317591 0.75 0.96984631 0.96984631\n", - " 0.75 0.41317591 0.11697778 0. ]\n" - ] - }, - { - "data": { - "text/plain": [ - "array([0. , 0.0954915, 0.3454915, 0.6545085, 0.9045085, 1. ,\n", - " 0.9045085, 0.6545085, 0.3454915, 0.0954915])" - ] - }, - "execution_count": 129, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(np.hanning(10))\n", - "from scipy.signal import get_window\n", - "get_window('hann', 10, fftbins=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "professional-journalism", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 153, - "id": "involved-motion", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(522, 400)\n", - "[[ 43. 75. 69. ... 46. 46. 45.]\n", - " [ 210. 215. 216. ... -86. -89. -91.]\n", - " [ 128. 128. 128. ... -154. -151. -151.]\n", - " ...\n", - " [ -60. -61. -61. ... 112. 109. 110.]\n", - " [ 20. 22. 24. ... 91. 87. 87.]\n", - " [ 111. 107. 108. ... -6. -4. -8.]]\n", - "torch.Size([1, 1, 83792])\n", - "torch.Size([400, 1, 512])\n", - "torch.Size([1, 400, 521])\n", - "conv frame tensor([[ 43., 75., 69., ..., 46., 46., 45.],\n", - " [ 210., 215., 216., ..., -86., -89., -91.],\n", - " [ 128., 128., 128., ..., -154., -151., -151.],\n", - " ...,\n", - " [-143., -141., -142., ..., 96., 101., 101.],\n", - " [ -60., -61., -61., ..., 112., 109., 110.],\n", - " [ 20., 22., 24., ..., 91., 87., 87.]])\n", - "xx [[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", - " 2.4064583e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", - " 9.5781029e+01 1.4200000e+02]\n", - " ...\n", - " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", - " 9.1511757e+01 1.1500000e+02]\n", - " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", - " 7.8405365e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", - " 5.1310158e+01 3.5000000e+01]]\n", - "torch.Size([521, 257])\n", - "yy [[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", - " 9.15117270e+01 1.15000000e+02]\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]]\n", - "yy (522, 257)\n", - "[[5.8976000e+04 2.5100676e+04 8.5960371e+03 ... 2.0281837e+01\n", - " 2.4064583e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724253e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117480e+04 5.2880322e+03 ... 2.8501144e+01\n", - " 9.5781029e+01 1.4200000e+02]\n", - " ...\n", - " [2.1113000e+04 2.3099363e+04 7.1594033e+03 ... 3.1945959e+01\n", - " 9.1511757e+01 1.1500000e+02]\n", - " [1.6772000e+04 2.1322793e+04 4.0607855e+02 ... 2.6011946e+01\n", - " 7.8405365e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706826e+03 ... 6.1070789e+01\n", - " 5.1310158e+01 3.5000000e+01]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [2.11130000e+04 2.30993602e+04 7.15940084e+03 ... 3.19459779e+01\n", - " 9.15117270e+01 1.15000000e+02]\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]]\n", - "False\n" - ] - } - ], - "source": [ - "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "print(f.shape)\n", - "print(f)\n", - "\n", - "n_fft=512\n", - "freq_bins = n_fft//2+1\n", - "s = np.arange(0, n_fft, 1.)\n", - "wsin = np.empty((freq_bins,1,n_fft))\n", - "wcos = np.empty((freq_bins,1,n_fft))\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)\n", - "\n", - "\n", - "wsin = np.empty((n_fft,1,n_fft))\n", - "wcos = np.empty((n_fft,1,n_fft))\n", - "for k in range(n_fft): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.eye(n_fft, n_fft)[k]\n", - " wcos[k,0,:] = np.eye(n_fft, n_fft)[k]\n", - " \n", - " \n", - "wsin = np.empty((400,1,n_fft))\n", - "wcos = np.empty((400,1,n_fft))\n", - "for k in range(400): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.eye(400, n_fft)[k]\n", - " wcos[k,0,:] = np.eye(400, n_fft)[k]\n", - " \n", - "\n", - " \n", - "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", - "x = x[None, None, :]\n", - "print(x.size())\n", - "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", - "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", - "print(kernel_sin.size())\n", - "\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "spec_imag = conv1d(x, kernel_sin, stride=160)\n", - "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", - "\n", - "print(spec_imag.size())\n", - "print(\"conv frame\", spec_imag[0].T)\n", - "# print(spec_imag[0].T[:, :400])\n", - "\n", - "# remove redundant parts\n", - "# spec_real = spec_real[:, :freq_bins, :]\n", - "# spec_imag = spec_imag[:, :freq_bins, :]\n", - "# spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "# spec = torch.sqrt(spec)\n", - "# print(spec)\n", - "\n", - "\n", - "\n", - "s = np.arange(0, 512, 1.)\n", - "# s = s[::-1]\n", - "wsin = np.empty((freq_bins, 400))\n", - "wcos = np.empty((freq_bins, 400))\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", - " wcos[k,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", - "\n", - "spec_real = torch.mm(spec_imag[0].T, torch.tensor(wcos, dtype=torch.float).T)\n", - "spec_imag = torch.mm(spec_imag[0].T, torch.tensor(wsin, dtype=torch.float).T)\n", - "\n", - "\n", - "# remove redundant parts\n", - "spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "spec = torch.sqrt(spec)\n", - "\n", - "print('xx', spec.numpy())\n", - "print(spec.size())\n", - "print('yy', rspec[:521, :])\n", - "print('yy', rspec.shape)\n", - "\n", - "\n", - "x = spec.numpy()\n", - "y = rspec[:-1, :]\n", - "print(x)\n", - "print(y)\n", - "print(np.allclose(x, y))" - ] - }, - { - "cell_type": "code", - "execution_count": 160, - "id": "mathematical-traffic", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([257, 1, 400])\n", - "tensor([[[5.8976e+04, 2.9266e+04, 1.9630e+04, ..., 1.6772e+04,\n", - " 3.8693e+04, 3.1020e+04],\n", - " [2.5101e+04, 2.7298e+04, 2.8117e+04, ..., 2.1323e+04,\n", - " 1.3598e+04, 1.5920e+04],\n", - " [8.5960e+03, 4.7724e+03, 5.2880e+03, ..., 4.0608e+02,\n", - " 6.7707e+03, 4.3020e+03],\n", - " ...,\n", - " [2.0282e+01, 6.6927e+01, 2.8501e+01, ..., 2.6012e+01,\n", - " 6.1071e+01, 5.3685e+01],\n", - " [2.4065e+01, 1.1878e+02, 9.5781e+01, ..., 7.8405e+01,\n", - " 5.1310e+01, 6.3620e+01],\n", - " [2.2000e+01, 1.6200e+02, 1.4200e+02, ..., 9.0000e+01,\n", - " 3.5000e+01, 4.4000e+01]]])\n", - "[[5.8976000e+04 2.5100672e+04 8.5960391e+03 ... 2.0281828e+01\n", - " 2.4064537e+01 2.2000000e+01]\n", - " [2.9266000e+04 2.7298107e+04 4.7724243e+03 ... 6.6926659e+01\n", - " 1.1877571e+02 1.6200000e+02]\n", - " [1.9630000e+04 2.8117475e+04 5.2880312e+03 ... 2.8501148e+01\n", - " 9.5781006e+01 1.4200000e+02]\n", - " ...\n", - " [1.6772000e+04 2.1322793e+04 4.0607657e+02 ... 2.6011934e+01\n", - " 7.8405350e+01 9.0000000e+01]\n", - " [3.8693000e+04 1.3598203e+04 6.7706841e+03 ... 6.1070808e+01\n", - " 5.1310150e+01 3.5000000e+01]\n", - " [3.1020000e+04 1.5920403e+04 4.3019902e+03 ... 5.3685162e+01\n", - " 6.3619797e+01 4.4000000e+01]]\n", - "[[5.89760000e+04 2.51006729e+04 8.59603890e+03 ... 2.02818313e+01\n", - " 2.40645984e+01 2.20000000e+01]\n", - " [2.92660000e+04 2.72981079e+04 4.77242582e+03 ... 6.69265842e+01\n", - " 1.18775735e+02 1.62000000e+02]\n", - " [1.96300000e+04 2.81174834e+04 5.28803149e+03 ... 2.85011387e+01\n", - " 9.57810428e+01 1.42000000e+02]\n", - " ...\n", - " [1.67720000e+04 2.13227930e+04 4.06079895e+02 ... 2.60119790e+01\n", - " 7.84053656e+01 9.00000000e+01]\n", - " [3.86930000e+04 1.35982074e+04 6.77068420e+03 ... 6.10707909e+01\n", - " 5.13101944e+01 3.50000000e+01]\n", - " [3.10200000e+04 1.59203961e+04 4.30198496e+03 ... 5.36851600e+01\n", - " 6.36197377e+01 4.40000000e+01]]\n", - "False\n" - ] - } - ], - "source": [ - "f = frames(wav, samplerate=16000,winlen=0.025,winstep=0.01,\n", - " nfilt=40, nfft=512,lowfreq=0,highfreq=None,\n", - " dither=0.0,remove_dc_offset=False, preemph=1.0, \n", - " wintype='hamming')\n", - "\n", - "n_fft=512\n", - "freq_bins = n_fft//2+1\n", - "s = np.arange(0, n_fft, 1.)\n", - "wsin = np.empty((freq_bins,1,400))\n", - "wcos = np.empty((freq_bins,1,400)) #[Cout, Cin, kernel_size]\n", - "for k in range(freq_bins): # Only half of the bins contain useful info\n", - " wsin[k,0,:] = np.sin(2*np.pi*k*s/n_fft)[:400]\n", - " wcos[k,0,:] = np.cos(2*np.pi*k*s/n_fft)[:400]\n", - "\n", - " \n", - "x = torch.tensor(wav).float() # casting the array into a PyTorch Tensor\n", - "x = x[None, None, :] #[B, C, T]\n", - "\n", - "kernel_sin = torch.tensor(wsin, dtype=torch.float)\n", - "kernel_cos = torch.tensor(wcos, dtype=torch.float)\n", - "print(kernel_sin.size())\n", - "\n", - "from torch.nn.functional import conv1d, conv2d, fold\n", - "spec_imag = conv1d(x, kernel_sin, stride=160) #[1, Cout, T]\n", - "spec_real = conv1d(x, kernel_cos, stride=160) # Doing STFT by using conv1d\n", - "\n", - "# remove redundant parts\n", - "spec = spec_real.pow(2) + spec_imag.pow(2)\n", - "spec = torch.sqrt(spec)\n", - "print(spec)\n", - "\n", - "x = spec[0].T.numpy()\n", - "y = rspec[:, :]\n", - "print(x)\n", - "print(y)\n", - "print(np.allclose(x, y))" - ] - }, - { - "cell_type": "code", - "execution_count": 162, - "id": "olive-nicaragua", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in true_divide\n", - " \"\"\"Entry point for launching an IPython kernel.\n" - ] - }, - { - "data": { - "text/plain": [ - "27241" - ] - }, - "execution_count": 162, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.argmax(np.abs(x -y) / np.abs(y))" - ] - }, - { - "cell_type": "code", - "execution_count": 165, - "id": "ultimate-assault", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.0" - ] - }, - "execution_count": 165, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y[np.unravel_index(27241, y.shape)]" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "institutional-stock", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.2412265e-10" - ] - }, - "execution_count": 166, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x[np.unravel_index(27241, y.shape)]" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "id": "integrated-courage", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(y, x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "different-operation", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/compute_cmvn_loader_test.ipynb b/.notebook/compute_cmvn_loader_test.ipynb deleted file mode 100644 index 2b0a8b75f..000000000 --- a/.notebook/compute_cmvn_loader_test.ipynb +++ /dev/null @@ -1,793 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "purple-consequence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "defensive-mason", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "patient-convention", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(delta_delta=False, feat_dim=80, manifest_path='examples/aishell/s1/data/manifest.train.raw', num_samples=-1, num_workers=16, output_path='data/librispeech/mean_std.npz', sample_rate=16000, specgram_type='fbank', stride_ms=10.0, window_ms=25.0)\n" - ] - } - ], - "source": [ - "import argparse\n", - "import functools\n", - "\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "from deepspeech.utils.utility import add_arguments\n", - "from deepspeech.utils.utility import print_arguments\n", - "\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, -1, \"# of samples to for statistics.\")\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc, fbank.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n", - "add_arg('delta_delta', bool,\n", - " False,\n", - " \"Audio feature with delta delta.\")\n", - "add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n", - "add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n", - "add_arg('sample_rate', int, 16000, \"target sample rate.\")\n", - "add_arg('manifest_path', str,\n", - " 'examples/aishell/s1/data/manifest.train.raw',\n", - " \"Filepath of manifest to compute normalizer's mean and stddev.\")\n", - "add_arg('num_workers',\n", - " default=16,\n", - " type=int,\n", - " help='num of subprocess workers for processing')\n", - "add_arg('output_path', str,\n", - " 'data/librispeech/mean_std.npz',\n", - " \"Filepath of write mean and stddev to (.npz).\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "enormous-currency", - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "\n", - "import numpy as np\n", - "import paddle\n", - "from paddle.io import DataLoader\n", - "from paddle.io import Dataset\n", - "\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "class CollateFunc(object):\n", - " ''' Collate function for AudioDataset\n", - " '''\n", - " def __init__(self):\n", - " pass\n", - " \n", - " def __call__(self, batch):\n", - " mean_stat = None\n", - " var_stat = None\n", - " number = 0\n", - " for feat in batch:\n", - " sums = np.sum(feat, axis=1)\n", - " if mean_stat is None:\n", - " mean_stat = sums\n", - " else:\n", - " mean_stat += sums\n", - "\n", - " square_sums = np.sum(np.square(feat), axis=1)\n", - " if var_stat is None:\n", - " var_stat = square_sums\n", - " else:\n", - " var_stat += square_sums\n", - "\n", - " number += feat.shape[1]\n", - " #return paddle.to_tensor(number), paddle.to_tensor(mean_stat), paddle.to_tensor(var_stat)\n", - " return number, mean_stat, var_stat\n", - "\n", - "\n", - "class AudioDataset(Dataset):\n", - " def __init__(self, manifest_path, feature_func, num_samples=-1, rng=None):\n", - " self.feature_func = feature_func\n", - " self._rng = rng\n", - " manifest = read_manifest(manifest_path)\n", - " if num_samples == -1:\n", - " sampled_manifest = manifest\n", - " else:\n", - " sampled_manifest = self._rng.sample(manifest, num_samples)\n", - " self.items = sampled_manifest\n", - "\n", - " def __len__(self):\n", - " return len(self.items)\n", - "\n", - " def __getitem__(self, idx):\n", - " key = self.items[idx]['feat']\n", - " audioseg = AudioSegment.from_file(key)\n", - " feat = self.feature_func(audioseg) #(D, T)\n", - " return feat" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "armed-semester", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "process 1000 wavs,450739 frames\n", - "process 2000 wavs,887447 frames\n", - "process 3000 wavs,1354148 frames\n", - "process 4000 wavs,1816494 frames\n", - "process 5000 wavs,2359211 frames\n", - "process 6000 wavs,2828455 frames\n", - "process 7000 wavs,3276186 frames\n", - "process 8000 wavs,3692234 frames\n", - "process 9000 wavs,4139360 frames\n", - "process 10000 wavs,4591528 frames\n", - "process 11000 wavs,5020114 frames\n", - "process 12000 wavs,5459523 frames\n", - "process 13000 wavs,5899534 frames\n", - "process 14000 wavs,6323242 frames\n", - "process 15000 wavs,6736597 frames\n", - "process 16000 wavs,7207686 frames\n", - "process 17000 wavs,7637800 frames\n", - "process 18000 wavs,8093004 frames\n", - "process 19000 wavs,8529518 frames\n", - "process 20000 wavs,8906022 frames\n", - "process 21000 wavs,9352652 frames\n", - "process 22000 wavs,9807495 frames\n", - "process 23000 wavs,10247938 frames\n", - "process 24000 wavs,10700011 frames\n", - "process 25000 wavs,11126134 frames\n", - "process 26000 wavs,11558061 frames\n", - "process 27000 wavs,12010359 frames\n", - "process 28000 wavs,12470938 frames\n", - "process 29000 wavs,12916013 frames\n", - "process 30000 wavs,13345816 frames\n", - "process 31000 wavs,13752365 frames\n", - "process 32000 wavs,14174801 frames\n", - "process 33000 wavs,14642170 frames\n", - "process 34000 wavs,15053557 frames\n", - "process 35000 wavs,15531890 frames\n", - "process 36000 wavs,16022711 frames\n", - "process 37000 wavs,16437688 frames\n", - "process 38000 wavs,16859517 frames\n", - "process 39000 wavs,17307676 frames\n", - "process 40000 wavs,17796629 frames\n", - "process 41000 wavs,18264151 frames\n", - "process 42000 wavs,18711898 frames\n", - "process 43000 wavs,19159890 frames\n", - "process 44000 wavs,19576435 frames\n", - "process 45000 wavs,19992793 frames\n", - "process 46000 wavs,20464449 frames\n", - "process 47000 wavs,20886021 frames\n", - "process 48000 wavs,21317318 frames\n", - "process 49000 wavs,21738034 frames\n", - "process 50000 wavs,22171890 frames\n", - "process 51000 wavs,22622238 frames\n", - "process 52000 wavs,23100734 frames\n", - "process 53000 wavs,23526901 frames\n", - "process 54000 wavs,23969746 frames\n", - "process 55000 wavs,24418691 frames\n", - "process 56000 wavs,24862546 frames\n", - "process 57000 wavs,25336448 frames\n", - "process 58000 wavs,25778435 frames\n", - "process 59000 wavs,26216199 frames\n", - "process 60000 wavs,26694692 frames\n", - "process 61000 wavs,27148978 frames\n", - "process 62000 wavs,27617088 frames\n", - "process 63000 wavs,28064946 frames\n", - "process 64000 wavs,28519843 frames\n", - "process 65000 wavs,28989722 frames\n", - "process 66000 wavs,29470156 frames\n", - "process 67000 wavs,29952931 frames\n", - "process 68000 wavs,30360555 frames\n", - "process 69000 wavs,30797929 frames\n", - "process 70000 wavs,31218227 frames\n", - "process 71000 wavs,31663934 frames\n", - "process 72000 wavs,32107468 frames\n", - "process 73000 wavs,32541943 frames\n", - "process 74000 wavs,33010702 frames\n", - "process 75000 wavs,33448082 frames\n", - "process 76000 wavs,33886812 frames\n", - "process 77000 wavs,34338108 frames\n", - "process 78000 wavs,34761495 frames\n", - "process 79000 wavs,35199730 frames\n", - "process 80000 wavs,35669630 frames\n", - "process 81000 wavs,36122402 frames\n", - "process 82000 wavs,36604561 frames\n", - "process 83000 wavs,37085552 frames\n", - "process 84000 wavs,37517500 frames\n", - "process 85000 wavs,37987196 frames\n", - "process 86000 wavs,38415721 frames\n", - "process 87000 wavs,38889467 frames\n", - "process 88000 wavs,39337809 frames\n", - "process 89000 wavs,39792342 frames\n", - "process 90000 wavs,40287946 frames\n", - "process 91000 wavs,40719461 frames\n", - "process 92000 wavs,41178919 frames\n", - "process 93000 wavs,41659635 frames\n", - "process 94000 wavs,42132985 frames\n", - "process 95000 wavs,42584564 frames\n", - "process 96000 wavs,43018598 frames\n", - "process 97000 wavs,43480662 frames\n", - "process 98000 wavs,43973670 frames\n", - "process 99000 wavs,44448190 frames\n", - "process 100000 wavs,44935034 frames\n", - "process 101000 wavs,45379812 frames\n", - "process 102000 wavs,45821207 frames\n", - "process 103000 wavs,46258420 frames\n", - "process 104000 wavs,46743733 frames\n", - "process 105000 wavs,47206922 frames\n", - "process 106000 wavs,47683041 frames\n", - "process 107000 wavs,48122809 frames\n", - "process 108000 wavs,48594623 frames\n", - "process 109000 wavs,49086358 frames\n", - "process 110000 wavs,49525568 frames\n", - "process 111000 wavs,49985820 frames\n", - "process 112000 wavs,50428262 frames\n", - "process 113000 wavs,50897957 frames\n", - "process 114000 wavs,51344589 frames\n", - "process 115000 wavs,51774621 frames\n", - "process 116000 wavs,52243372 frames\n", - "process 117000 wavs,52726025 frames\n", - "process 118000 wavs,53170026 frames\n", - "process 119000 wavs,53614141 frames\n", - "process 120000 wavs,54071271 frames\n" - ] - } - ], - "source": [ - "\n", - "augmentation_pipeline = AugmentationPipeline('{}')\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20)\n", - "\n", - "def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - "\n", - "collate_func = CollateFunc()\n", - "\n", - "dataset = AudioDataset(\n", - " args.manifest_path,\n", - " augment_and_featurize, \n", - " args.num_samples)\n", - "\n", - "batch_size = 20\n", - "data_loader = DataLoader(\n", - " dataset,\n", - " batch_size=batch_size,\n", - " shuffle=False,\n", - " num_workers=args.num_workers,\n", - " collate_fn=collate_func)\n", - "\n", - "with paddle.no_grad():\n", - " all_mean_stat = None\n", - " all_var_stat = None\n", - " all_number = 0\n", - " wav_number = 0\n", - " for i, batch in enumerate(data_loader()):\n", - " #for batch in data_loader():\n", - " number, mean_stat, var_stat = batch\n", - " if i == 0:\n", - " all_mean_stat = mean_stat\n", - " all_var_stat = var_stat\n", - " else:\n", - " all_mean_stat += mean_stat\n", - " all_var_stat += var_stat\n", - " all_number += number\n", - " wav_number += batch_size\n", - "\n", - " if wav_number % 1000 == 0:\n", - " print('process {} wavs,{} frames'.format(wav_number,\n", - " all_number))\n", - "\n", - "cmvn_info = {\n", - " 'mean_stat': list(all_mean_stat.tolist()),\n", - " 'var_stat': list(all_var_stat.tolist()),\n", - " 'frame_num': all_number\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "danish-executive", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'mean_stat': [-813852467.7953382, -769025957.9140725, -809499593.411409, -774700574.014532, -750961217.5896736, -760564397.2864963, -805662399.3771614, -843490965.4231446, -850242081.9416809, -857678651.504435, -879067453.9826999, -908602072.3856701, -936850957.7187386, -957242686.489041, -968425442.0916103, -972687545.5953809, -980383731.7683417, -991533337.6343704, -1001966818.1164789, -1010334169.7486078, -1016855066.9099333, -1022176245.7021623, -1025700476.4788507, -1030678878.3195274, -1037075963.124199, -1042705719.0195516, -1047422212.6492896, -1049003537.271861, -1050314833.7453628, -1050772191.0204058, -1050010034.9948177, -1050436065.1336465, -1053327181.7978873, -1058710548.2036785, -1065950852.4966162, -1071709705.0060445, -1077682778.259181, -1083371045.272074, -1089708906.2657735, -1096312217.7865202, -1101089858.8364556, -1104965332.4332569, -1107791702.5223634, -1109431075.2374773, -1110066333.0280604, -1110382732.0722318, -1110480306.3793216, -1110203297.7110727, -1109972534.3583376, -1109378081.8792782, -1108212059.413654, -1107235713.2041805, -1106973581.9280007, -1107352339.7860134, -1108730029.862537, -1110425202.83704, -1113220669.4552443, -1115887535.4870913, -1118105356.3628063, -1120001376.8503075, -1121135822.320366, -1122265971.8751016, -1123990217.401155, -1125786729.6230593, -1127784957.2745507, -1129180108.9033566, -1132000461.6688302, -1134675829.8190608, -1137652487.5164194, -1141755948.0463965, -1145340901.5468378, -1148637682.593287, -1151755522.470022, -1154981643.2268832, -1157417488.840151, -1161240429.0989249, -1165411128.671642, -1170521097.1034513, -1176307165.5109766, -1183456865.0039694, -1190535938.6591117, -1197946309.0472982, -1203596565.037139, -1207563038.1241052, -1209707561.5829782, -1211407066.2452552, -1211884576.9201162, -1212778872.005509, -1214041413.8080075, -1215367953.1745043, -1216850831.482193, -1217678325.5351057, -1218854289.54188, -1219325064.8610544, -1219080344.7580786, -1218541313.657531, -1217889833.2067819, -1216552930.1654336, -1216423777.4113154, -1216575252.225508, -1217075384.9826024, -1217391577.901724, -1217838974.57273, -1218131805.6054134, -1218294889.7465532, -1218566666.1755593, -1218790537.5519717, -1218748668.9956846, -1218603191.4941735, -1218004566.4348054, -1217312410.127734, -1217207493.9522285, -1217284002.3834674, -1217644312.51745, -1218039821.6444128, -1218721811.6269798, -1219121088.9265897, -1219014460.8090584, -1218530127.6776083, -1217952335.451711, -1217316073.8666434, -1217035380.1151958, -1216636431.2964456, -1216257015.2945514, -1215658496.1208403, -1215097272.0976632, -1214669859.2064147, -1214593853.4809475, -1214599475.7838447, -1214575440.823035, -1214158828.8008435, -1213482920.2673717, -1212476577.5897374, -1211251374.2198513, -1210284855.590475, -1209302456.065669, -1209106252.6625297, -1209373211.5146718, -1209689421.7984035, -1210021342.495856, -1210650609.3592312, -1211428521.3900626, -1212616111.4257205, -1213820075.2948189, -1215320588.7144456, -1217175082.2739282, -1219703351.4585004, -1222007827.120464, -1224637375.5900724, -1228367798.912171, -1234853879.862459, -1247222219.867692, -1268562808.1616178, -1302034822.9569275, -1347823631.0776038, -1402753916.9445229, -1458826717.3262982, -1505843092.0970414, -1534278782.249077, -1543955545.8994718, -1600409154.893352], 'var_stat': [12665413908.91729, 11145088801.244318, 12567119446.035736, 11758392758.06822, 11200687982.736668, 11551903443.711124, 12880777868.435602, 14084854368.236998, 14394011058.866192, 14678818621.277662, 15346278722.626339, 16268053979.757076, 17191705347.854794, 17877540386.548733, 18251857849.077663, 18392628178.710472, 18645534548.4045, 19018598212.22902, 19366711357.782673, 19655730286.72857, 19890681996.786858, 20094163350.461906, 20227774955.225887, 20423525628.66887, 20669928826.76939, 20882313568.247944, 21062392676.270527, 21126648821.879055, 21185210734.751118, 21209014745.520447, 21182293842.91236, 21197433134.875977, 21302147790.662144, 21504666657.651955, 21781818550.89697, 21996170165.145462, 22217169779.096275, 22431161762.176693, 22672708668.38104, 22922683961.072956, 23101137011.201683, 23249680793.556847, 23358894817.24979, 23422895267.919228, 23449479198.303394, 23464433357.671055, 23469197140.124596, 23459013479.866177, 23447935341.542686, 23422585038.052387, 23375601301.949135, 23338397991.497776, 23329682884.21905, 23348002892.39853, 23406274659.89975, 23478242518.92228, 23592891371.876236, 23703885161.772205, 23797158601.65954, 23875230355.66992, 23918333664.3946, 23968582109.371258, 24040547318.081936, 24112364295.110058, 24189973697.612144, 24242165205.640236, 24364255205.82311, 24472408850.760197, 24590211203.05312, 24763026764.005527, 24909192634.69144, 25043438176.23281, 25167141466.500504, 25297108031.48665, 25395377064.0999, 25550930772.86505, 25721404827.10336, 25931101211.156487, 26168988710.098465, 26465528802.762875, 26760033029.443783, 27075408488.605213, 27316626931.655052, 27487275073.52796, 27579518448.2332, 27652308513.875782, 27673412508.45838, 27711509210.702576, 27767312240.641487, 27827464683.295334, 27894794590.957966, 27935988489.16511, 27992337099.891083, 28019655483.58796, 28014286886.252903, 27996189233.857716, 27973078840.875465, 27920045013.68706, 27917103211.22359, 27927566165.64652, 27953525818.61368, 27973386070.140022, 27999317832.502476, 28019494120.641834, 28033010746.452637, 28051086123.896503, 28066195174.191753, 28068570977.318798, 28064890246.85437, 28042424375.860577, 28015849655.869568, 28014812222.566605, 28021039053.959835, 28039270607.169422, 28058271295.10199, 28088976520.10178, 28107824988.74732, 28105633030.784756, 28087681357.818607, 28065484299.963837, 28039555887.004284, 28028214431.52875, 28011714871.929447, 27995603790.480755, 27970125897.561134, 27946436130.511288, 27929044772.5522, 27926612443.390316, 27926256324.387302, 27924771848.71099, 27905526922.390133, 27876268519.168198, 27832532606.552593, 27779497699.976765, 27737034351.907337, 27692129825.179924, 27684252911.371475, 27698882622.878677, 27712387157.27985, 27726474638.933037, 27752647691.051613, 27786197932.382797, 27836378752.662235, 27887415700.334576, 27949784230.702114, 28028117657.84245, 28136313097.200474, 28234098926.207996, 28345845477.25874, 28507222800.146496, 28793832339.90449, 29350765483.070816, 30328262350.231213, 31894930713.76519, 34093669067.422382, 36801959396.22739, 39638995447.49344, 42088579425.44825, 43616108982.85117, 44152063315.31461, 47464832889.5967], 'frame_num': 54129649}\n" - ] - } - ], - "source": [ - "print(cmvn_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "accurate-terminal", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "dominant-abuse", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "process 1000 wavs,450240 frames\n", - " \n", - "process 2000 wavs,886411 frames\n", - " \n", - "process 3000 wavs,1352580 frames\n", - " \n", - "process 4000 wavs,1814397 frames\n", - " \n", - "process 5000 wavs,2356587 frames\n", - " \n", - "process 6000 wavs,2825310 frames\n", - " \n", - "process 7000 wavs,3272506 frames\n", - " \n", - "process 8000 wavs,3688045 frames\n", - " \n", - "process 9000 wavs,4134669 frames\n", - " \n", - "process 10000 wavs,4586357 frames\n", - " \n", - "process 11000 wavs,5014429 frames\n", - " \n", - "process 12000 wavs,5453334 frames\n", - " \n", - "process 13000 wavs,5892888 frames\n", - " \n", - "process 14000 wavs,6316059 frames\n", - " \n", - "process 15000 wavs,6728870 frames\n", - " \n", - "process 16000 wavs,7199442 frames\n", - " \n", - "process 17000 wavs,7629055 frames\n", - " \n", - "process 18000 wavs,8083729 frames\n", - " \n", - "process 19000 wavs,8519732 frames\n", - " \n", - "process 20000 wavs,8895694 frames\n", - " \n", - "process 21000 wavs,9341778 frames\n", - " \n", - "process 22000 wavs,9796126 frames\n", - " \n", - "process 23000 wavs,10236057 frames\n", - " \n", - "process 24000 wavs,10687461 frames\n", - " \n", - "process 25000 wavs,11113082 frames\n", - " \n", - "process 26000 wavs,11544482 frames\n", - " \n", - "process 27000 wavs,11996273 frames\n", - " \n", - "process 28000 wavs,12456350 frames\n", - " \n", - "process 29000 wavs,12900895 frames\n", - " \n", - "process 30000 wavs,13330353 frames\n", - " \n", - "process 31000 wavs,13736568 frames\n", - " \n", - "process 32000 wavs,14158472 frames\n", - " \n", - "process 33000 wavs,14625316 frames\n", - " \n", - "process 34000 wavs,15036206 frames\n", - " \n", - "process 35000 wavs,15514001 frames\n", - " \n", - "process 36000 wavs,16004323 frames\n", - " \n", - "process 37000 wavs,16418799 frames\n", - " \n", - "process 38000 wavs,16840100 frames\n", - " \n", - "process 39000 wavs,17287752 frames\n", - " \n", - "process 40000 wavs,17776206 frames\n", - " \n", - "process 41000 wavs,18243209 frames\n", - " \n", - "process 42000 wavs,18690449 frames\n", - " \n", - "process 43000 wavs,19137940 frames\n", - " \n", - "process 44000 wavs,19553966 frames\n", - " \n", - "process 45000 wavs,19969813 frames\n", - " \n", - "process 46000 wavs,20440963 frames\n", - " \n", - "process 47000 wavs,20862022 frames\n", - " \n", - "process 48000 wavs,21292801 frames\n", - " \n", - "process 49000 wavs,21713004 frames\n", - " \n", - "process 50000 wavs,22146346 frames\n", - " \n", - "process 51000 wavs,22596172 frames\n", - " \n", - "process 52000 wavs,23074160 frames\n", - " \n", - "process 53000 wavs,23499823 frames\n", - " \n", - "process 54000 wavs,23942151 frames\n", - " \n", - "process 55000 wavs,24390566 frames\n", - " \n", - "process 56000 wavs,24833905 frames\n", - " \n", - "process 57000 wavs,25307270 frames\n", - " \n", - "process 58000 wavs,25748720 frames\n", - " \n", - "process 59000 wavs,26185964 frames\n", - " \n", - "process 60000 wavs,26663953 frames\n", - " \n", - "process 61000 wavs,27117720 frames\n", - " \n", - "process 62000 wavs,27585349 frames\n", - " \n", - "process 63000 wavs,28032693 frames\n", - " \n", - "process 64000 wavs,28487074 frames\n", - " \n", - "process 65000 wavs,28956462 frames\n", - " \n", - "process 66000 wavs,29436358 frames\n", - " \n", - "process 67000 wavs,29918569 frames\n", - " \n", - "process 68000 wavs,30325682 frames\n", - " \n", - "process 69000 wavs,30762528 frames\n", - " \n", - "process 70000 wavs,31182319 frames\n", - " \n", - "process 71000 wavs,31627526 frames\n", - " \n", - "process 72000 wavs,32070556 frames\n", - " \n", - "process 73000 wavs,32504534 frames\n", - " \n", - "process 74000 wavs,32972775 frames\n", - " \n", - "process 75000 wavs,33409637 frames\n", - " \n", - "process 76000 wavs,33847861 frames\n", - " \n", - "process 77000 wavs,34298647 frames\n", - " \n", - "process 78000 wavs,34721536 frames\n", - " \n", - "process 79000 wavs,35159236 frames\n", - " \n", - "process 80000 wavs,35628628 frames\n", - " \n", - "process 81000 wavs,36080909 frames\n", - " \n", - "process 82000 wavs,36562496 frames\n", - " \n", - "process 83000 wavs,37042976 frames\n", - " \n", - "process 84000 wavs,37474403 frames\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " \n", - "process 85000 wavs,37943596 frames\n", - " \n", - "process 86000 wavs,38371620 frames\n", - " \n", - "process 87000 wavs,38844874 frames\n", - " \n", - "process 88000 wavs,39292686 frames\n", - " \n", - "process 89000 wavs,39746715 frames\n", - " \n", - "process 90000 wavs,40241800 frames\n", - " \n", - "process 91000 wavs,40672817 frames\n", - " \n", - "process 92000 wavs,41131773 frames\n", - " \n", - "process 93000 wavs,41612001 frames\n", - " \n", - "process 94000 wavs,42084822 frames\n", - " \n", - "process 95000 wavs,42535878 frames\n", - " \n", - "process 96000 wavs,42969365 frames\n", - " \n", - "process 97000 wavs,43430890 frames\n", - " \n", - "process 98000 wavs,43923378 frames\n", - " \n", - "process 99000 wavs,44397370 frames\n", - " \n", - "process 100000 wavs,44883695 frames\n", - " \n", - "process 101000 wavs,45327968 frames\n", - " \n", - "process 102000 wavs,45768860 frames\n", - " \n", - "process 103000 wavs,46205602 frames\n", - " \n", - "process 104000 wavs,46690407 frames\n", - " \n", - "process 105000 wavs,47153089 frames\n", - " \n", - "process 106000 wavs,47628699 frames\n", - " \n", - "process 107000 wavs,48067945 frames\n", - " \n", - "process 108000 wavs,48539256 frames\n", - " \n", - "process 109000 wavs,49030485 frames\n", - " \n", - "process 110000 wavs,49469189 frames\n", - " \n", - "process 111000 wavs,49928968 frames\n", - " \n", - "process 112000 wavs,50370921 frames\n", - " \n", - "process 113000 wavs,50840090 frames\n", - " \n", - "process 114000 wavs,51286249 frames\n", - " \n", - "process 115000 wavs,51715786 frames\n", - " \n", - "process 116000 wavs,52184017 frames\n", - " \n", - "process 117000 wavs,52666156 frames\n", - " \n", - "process 118000 wavs,53109645 frames\n", - " \n", - "process 119000 wavs,53553253 frames\n", - " \n", - "process 120000 wavs,54009877 frames\n", - "{'mean_stat': [700612678.1184504, 704246512.9321843, 720430663.1822729, 754033269.0474415, 798737761.616614, 829467218.4204571, 851246702.9426627, 862261185.2661449, 859339943.6923889, 846303730.8696194, 832995109.605447, 823196536.6029147, 832626008.2569772, 845571326.1936859, 848801373.0562981, 846503549.328017, 836774344.5500796, 823481091.0445303, 820728368.2518216, 804571348.4957463, 795306095.0083207, 811729024.2415155, 805734803.5703195, 813076782.1959459, 806620199.406499, 809655573.8886961, 804371708.9347517, 809272248.6085774, 810322689.7490631, 814294131.1973915, 816262716.0476038, 816213124.2411841, 817158473.4380915, 821414211.5629157, 827408091.5728914, 834353896.0519086, 840094990.3467333, 842613218.6554606, 842070761.1727513, 834970952.5260613, 837020570.8200948, 829592602.7833654, 830116543.8893851, 829482316.3881509, 833397219.4597517, 839251633.3120549, 845475010.4718693, 852378426.7183967, 859563981.8633184, 866063840.5523493, 867790921.9978689, 868215100.5962687, 869683066.032885, 872467375.6674014, 873097681.1780069, 873025823.0543871, 869897292.7201596, 866386426.3869117, 863166726.7256871, 854653071.2244718, 842402803.9000899, 830838253.4144138, 830143002.3536818, 831492285.0310817, 833304371.8781006, 838896092.8621838, 843866088.9578133, 847316792.1429776, 851038022.3643295, 855931698.0149751, 859320543.9795249, 863031001.3470656, 868325062.1832993, 873626971.0115026, 878726636.924209, 884861725.972504, 886920281.5192285, 883056006.5094173, 863719240.7255149, 773378975.9476194], 'var_stat': [9237018652.657722, 9417257721.82426, 10105084297.159702, 11071318522.587782, 12422783727.426847, 13400306419.784964, 14148498843.406874, 14576436982.89939, 14529009036.494726, 14105645932.596651, 13682988821.478252, 13413013425.088106, 13764134927.293928, 14233704806.737064, 14361631309.367067, 14281358385.45644, 13939662689.213865, 13496884231.929493, 13382566162.783987, 12871350930.6626, 12576198160.876635, 13051463889.56708, 12859205935.513906, 13053861416.098743, 12830323588.550724, 12886405923.897238, 12708529922.84171, 12847306110.231739, 12880398489.53404, 13002566299.565536, 13066708060.463543, 13064231286.858614, 13088983337.353497, 13221393824.891022, 13412425607.755072, 13631485149.777075, 13807797519.156103, 13877277485.033077, 13848613909.96762, 13609176326.2529, 13649815250.130072, 13397698404.696907, 13388964704.359968, 13354326914.968012, 13469861474.898457, 13652539440.283333, 13846837321.329163, 14062143714.601675, 14292571198.61228, 14504626563.299246, 14563864749.132776, 14579720287.991764, 14626700787.353922, 14716185568.128899, 14728532777.28015, 14719101187.113443, 14607945896.239174, 14478517828.531614, 14355110561.681187, 14057430280.249746, 13634284490.879377, 13248236002.494394, 13217602306.335958, 13257856701.946049, 13323688441.072674, 13515395318.023148, 13685827169.67645, 13811622609.426846, 13947347160.615082, 14115883822.884943, 14231204526.433033, 14356066668.651815, 14533604268.238445, 14708971788.69237, 14875667326.732443, 15079098318.79331, 15144888989.667963, 15002658970.504765, 14349232841.34513, 11544480117.013124], 'frame_num': 54068199}\n" - ] - } - ], - "source": [ - "import random\n", - "\n", - "import numpy as np\n", - "import paddle\n", - "from paddle.io import DataLoader\n", - "from paddle.io import Dataset\n", - "\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "# https://github.com/PaddlePaddle/Paddle/pull/31481\n", - "class CollateFunc(object):\n", - " ''' Collate function for AudioDataset\n", - " '''\n", - " def __init__(self, feature_func):\n", - " self.feature_func = feature_func\n", - " \n", - " def __call__(self, batch):\n", - " mean_stat = None\n", - " var_stat = None\n", - " number = 0\n", - " for item in batch:\n", - " audioseg = AudioSegment.from_file(item['feat'])\n", - " feat = self.feature_func(audioseg) #(D, T)\n", - "\n", - " sums = np.sum(feat, axis=1)\n", - " if mean_stat is None:\n", - " mean_stat = sums\n", - " else:\n", - " mean_stat += sums\n", - "\n", - " square_sums = np.sum(np.square(feat), axis=1)\n", - " if var_stat is None:\n", - " var_stat = square_sums\n", - " else:\n", - " var_stat += square_sums\n", - "\n", - " number += feat.shape[1]\n", - " return number, mean_stat, var_stat\n", - "\n", - "\n", - "class AudioDataset(Dataset):\n", - " def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):\n", - " self._rng = rng if rng else np.random.RandomState(random_seed)\n", - " manifest = read_manifest(manifest_path)\n", - " if num_samples == -1:\n", - " sampled_manifest = manifest\n", - " else:\n", - " sampled_manifest = self._rng.choice(manifest, num_samples, replace=False)\n", - " self.items = sampled_manifest\n", - "\n", - " def __len__(self):\n", - " return len(self.items)\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.items[idx]\n", - " \n", - " \n", - "augmentation_pipeline = AugmentationPipeline('{}')\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20)\n", - "\n", - "def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - "\n", - "collate_func = CollateFunc(augment_and_featurize)\n", - "\n", - "dataset = AudioDataset(\n", - " args.manifest_path,\n", - " args.num_samples)\n", - "\n", - "batch_size = 20\n", - "data_loader = DataLoader(\n", - " dataset,\n", - " batch_size=batch_size,\n", - " shuffle=False,\n", - " num_workers=args.num_workers,\n", - " collate_fn=collate_func)\n", - "\n", - "with paddle.no_grad():\n", - " all_mean_stat = None\n", - " all_var_stat = None\n", - " all_number = 0\n", - " wav_number = 0\n", - " for i, batch in enumerate(data_loader):\n", - " number, mean_stat, var_stat = batch\n", - " if i == 0:\n", - " all_mean_stat = mean_stat\n", - " all_var_stat = var_stat\n", - " else:\n", - " all_mean_stat += mean_stat\n", - " all_var_stat += var_stat\n", - " all_number += number\n", - " wav_number += batch_size\n", - "\n", - " if wav_number % 1000 == 0:\n", - " print('process {} wavs,{} frames'.format(wav_number,\n", - " all_number))\n", - "\n", - "cmvn_info = {\n", - " 'mean_stat': list(all_mean_stat.tolist()),\n", - " 'var_stat': list(all_var_stat.tolist()),\n", - " 'frame_num': all_number\n", - "}\n", - "print(cmvn_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unlike-search", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/dataloader.ipynb b/.notebook/dataloader.ipynb deleted file mode 100644 index 3de8f64a9..000000000 --- a/.notebook/dataloader.ipynb +++ /dev/null @@ -1,389 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "emerging-meter", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "import math\n", - "import random\n", - "import tarfile\n", - "import logging\n", - "import numpy as np\n", - "from collections import namedtuple\n", - "from functools import partial\n", - "\n", - "import paddle\n", - "from paddle.io import Dataset\n", - "from paddle.io import DataLoader\n", - "from paddle.io import BatchSampler\n", - "from paddle.io import DistributedBatchSampler\n", - "from paddle import distributed as dist\n", - "\n", - "from data_utils.utility import read_manifest\n", - "from data_utils.augmentor.augmentation import AugmentationPipeline\n", - "from data_utils.featurizer.speech_featurizer import SpeechFeaturizer\n", - "from data_utils.speech import SpeechSegment\n", - "from data_utils.normalizer import FeatureNormalizer\n", - "\n", - "\n", - "from data_utils.dataset import (\n", - " DeepSpeech2Dataset,\n", - " DeepSpeech2DistributedBatchSampler,\n", - " DeepSpeech2BatchSampler,\n", - " SpeechCollator,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "excessive-american", - "metadata": {}, - "outputs": [], - "source": [ - "def create_dataloader(manifest_path,\t\n", - " vocab_filepath,\t\n", - " mean_std_filepath,\t\n", - " augmentation_config='{}',\t\n", - " max_duration=float('inf'),\t\n", - " min_duration=0.0,\t\n", - " stride_ms=10.0,\t\n", - " window_ms=20.0,\t\n", - " max_freq=None,\t\n", - " specgram_type='linear',\t\n", - " use_dB_normalization=True,\t\n", - " random_seed=0,\t\n", - " keep_transcription_text=False,\t\n", - " is_training=False,\t\n", - " batch_size=1,\t\n", - " num_workers=0,\t\n", - " sortagrad=False,\t\n", - " shuffle_method=None,\t\n", - " dist=False):\t\n", - "\n", - " dataset = DeepSpeech2Dataset(\t\n", - " manifest_path,\t\n", - " vocab_filepath,\t\n", - " mean_std_filepath,\t\n", - " augmentation_config=augmentation_config,\t\n", - " max_duration=max_duration,\t\n", - " min_duration=min_duration,\t\n", - " stride_ms=stride_ms,\t\n", - " window_ms=window_ms,\t\n", - " max_freq=max_freq,\t\n", - " specgram_type=specgram_type,\t\n", - " use_dB_normalization=use_dB_normalization,\t\n", - " random_seed=random_seed,\t\n", - " keep_transcription_text=keep_transcription_text)\t\n", - "\n", - " if dist:\t\n", - " batch_sampler = DeepSpeech2DistributedBatchSampler(\t\n", - " dataset,\t\n", - " batch_size,\t\n", - " num_replicas=None,\t\n", - " rank=None,\t\n", - " shuffle=is_training,\t\n", - " drop_last=is_training,\t\n", - " sortagrad=is_training,\t\n", - " shuffle_method=shuffle_method)\t\n", - " else:\t\n", - " batch_sampler = DeepSpeech2BatchSampler(\t\n", - " dataset,\t\n", - " shuffle=is_training,\t\n", - " batch_size=batch_size,\t\n", - " drop_last=is_training,\t\n", - " sortagrad=is_training,\t\n", - " shuffle_method=shuffle_method)\t\n", - "\n", - " def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):\t\n", - " \"\"\"\t\n", - " Padding audio features with zeros to make them have the same shape (or\t\n", - " a user-defined shape) within one bach.\t\n", - "\n", - " If ``padding_to`` is -1, the maximun shape in the batch will be used\t\n", - " as the target shape for padding. Otherwise, `padding_to` will be the\t\n", - " target shape (only refers to the second axis).\t\n", - "\n", - " If `flatten` is True, features will be flatten to 1darray.\t\n", - " \"\"\"\t\n", - " new_batch = []\t\n", - " # get target shape\t\n", - " max_length = max([audio.shape[1] for audio, text in batch])\t\n", - " if padding_to != -1:\t\n", - " if padding_to < max_length:\t\n", - " raise ValueError(\"If padding_to is not -1, it should be larger \"\t\n", - " \"than any instance's shape in the batch\")\t\n", - " max_length = padding_to\t\n", - " max_text_length = max([len(text) for audio, text in batch])\t\n", - " # padding\t\n", - " padded_audios = []\t\n", - " audio_lens = []\t\n", - " texts, text_lens = [], []\t\n", - " for audio, text in batch:\t\n", - " padded_audio = np.zeros([audio.shape[0], max_length])\t\n", - " padded_audio[:, :audio.shape[1]] = audio\t\n", - " if flatten:\t\n", - " padded_audio = padded_audio.flatten()\t\n", - " padded_audios.append(padded_audio)\t\n", - " audio_lens.append(audio.shape[1])\t\n", - "\n", - " padded_text = np.zeros([max_text_length])\n", - " if is_training:\n", - " padded_text[:len(text)] = text\t# ids\n", - " else:\n", - " padded_text[:len(text)] = [ord(t) for t in text] # string\n", - " \n", - " texts.append(padded_text)\t\n", - " text_lens.append(len(text))\t\n", - "\n", - " padded_audios = np.array(padded_audios).astype('float32')\t\n", - " audio_lens = np.array(audio_lens).astype('int64')\t\n", - " texts = np.array(texts).astype('int32')\t\n", - " text_lens = np.array(text_lens).astype('int64')\t\n", - " return padded_audios, texts, audio_lens, text_lens\t\n", - "\n", - " loader = DataLoader(\t\n", - " dataset,\t\n", - " batch_sampler=batch_sampler,\t\n", - " collate_fn=partial(padding_batch, is_training=is_training),\t\n", - " num_workers=num_workers)\t\n", - " return loader" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "naval-brave", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/data/manifest.dev',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/aishell/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'linear',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "bearing-physics", - "metadata": {}, - "outputs": [], - "source": [ - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=True,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "classified-melissa", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[22823, 26102, 20195, 37324, 0 , 0 ],\n", - " [22238, 26469, 23601, 22909, 0 , 0 ],\n", - " [20108, 26376, 22235, 26085, 0 , 0 ],\n", - " [36824, 35201, 20445, 25345, 32654, 24863],\n", - " [29042, 27748, 21463, 23456, 0 , 0 ]])\n", - "test raw 大时代里\n", - "test raw 煲汤受宠\n", - "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "test len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [4, 4, 4, 6, 4])\n", - "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", - " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", - " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", - " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", - " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", - " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", - " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", - " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", - " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", - " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", - " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", - " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", - " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", - " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", - " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", - " ...,\n", - " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", - " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", - " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", - "\n", - " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", - " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", - " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", - " ...,\n", - " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", - " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", - " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" - ] - } - ], - "source": [ - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test', text)\n", - " print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('audio len', audio_len)\n", - " print('test len', text_len)\n", - " print('audio', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unexpected-skating", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "minus-modern", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/.notebook/dataloader_with_tokens_tokenids.ipynb b/.notebook/dataloader_with_tokens_tokenids.ipynb deleted file mode 100644 index 7d93dd009..000000000 --- a/.notebook/dataloader_with_tokens_tokenids.ipynb +++ /dev/null @@ -1,1204 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "medieval-monday", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "emerging-meter", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" - ] - } - ], - "source": [ - "import math\n", - "import random\n", - "import tarfile\n", - "import logging\n", - "import numpy as np\n", - "from collections import namedtuple\n", - "from functools import partial\n", - "\n", - "import paddle\n", - "from paddle.io import Dataset\n", - "from paddle.io import DataLoader\n", - "from paddle.io import BatchSampler\n", - "from paddle.io import DistributedBatchSampler\n", - "from paddle import distributed as dist\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "excessive-american", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "naval-brave", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:93] register user softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:97] register user log_softmax to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:101] register user sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:105] register user log_sigmoid to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:109] register user relu to paddle, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:119] override cat of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:133] override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:144] override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:164] override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:179] override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:185] override eq of paddle if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:195] override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:212] override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:223] register user view to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:233] register user view_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:259] register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:277] register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:288] register user fill_ to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:298] register user repeat to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:303] register user softmax to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:308] register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:312] register user relu to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:322] register user type_as to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:337] register user to to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:346] register user float to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:356] register user tolist to paddle.Tensor, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:371] register user glu to paddle.nn.functional, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:422] override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:428] register user Module to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:434] register user ModuleList to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:450] register user GLU to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:483] register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "[WARNING 2021/04/16 06:32:09 __init__.py:489] register user export to paddle.jit, remove this when fixed!\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/tiny/s1/data/spm_bpe', 'infer_manifest': 'examples/tiny/s1/data/manifest.tiny', 'mean_std_path': 'examples/tiny/s1/data/mean_std.npz', 'vocab_path': 'examples/tiny/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/tiny/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('unit_type', str,\n", - " 'char',\n", - " \"Options: char, word, spm.\",\n", - " choices=['char', 'word', 'spm'])\n", - "add_arg('spm_model_prefix', str,\n", - " 'examples/tiny/s1/data/spm_bpe',\n", - " \"spm model prefix.\",)\n", - "add_arg('infer_manifest', str,\n", - " 'examples/tiny/s1/data/manifest.tiny',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/tiny/s1/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/tiny/s1/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/tiny/s1/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n", - "add_arg('delta_delta', bool, False, \"delta delta\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "wired-principal", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'unit_type': 'char', 'spm_model_prefix': 'examples/aishell/s1/data/spm_bpe', 'infer_manifest': 'examples/aishell/s1/data/manifest.test', 'mean_std_path': '', 'vocab_path': 'examples/aishell/s1/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/s1/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('unit_type', str,\n", - " 'char',\n", - " \"Options: char, word, spm.\",\n", - " choices=['char', 'word', 'spm'])\n", - "add_arg('spm_model_prefix', str,\n", - " 'examples/aishell/s1/data/spm_bpe',\n", - " \"spm model prefix.\",)\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/s1/data/manifest.test',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " '',\n", - " \"examples/aishell/s1/data/mean_std.npz, Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/s1/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/s1/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"mfcc or fbank feat dim.\")\n", - "add_arg('delta_delta', bool, False, \"delta delta\")\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "bearing-physics", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer\n", - "from deepspeech.frontend.speech import SpeechSegment\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "\n", - "\n", - "from deepspeech.io.collator import SpeechCollator\n", - "from deepspeech.io.dataset import ManifestDataset\n", - "from deepspeech.io.sampler import (\n", - " SortagradDistributedBatchSampler,\n", - " SortagradBatchSampler,\n", - ")\n", - "from deepspeech.io import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " unit_type=args.unit_type,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " spm_model_prefix=args.spm_model_prefix,\n", - " augmentation_config='{}',\n", - " max_input_len=27.0,\n", - " min_input_len=0.0,\n", - " max_output_len=float('inf'),\n", - " min_output_len=0.0,\n", - " max_output_input_ratio=float('inf'),\n", - " min_output_input_ratio=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=True,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " num_workers=0,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "classified-melissa", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[232 387 331 ... 249 249 262] int16\n", - "fbank\n", - "[-138 -219 -192 ... 338 324 351] int16\n", - "fbank\n", - "[ 694 1175 1022 ... 553 514 627] int16\n", - "fbank\n", - "[-39 -79 -53 ... 139 172 99] int16\n", - "fbank\n", - "[-277 -480 -425 ... 758 767 739] int16\n", - "fbank\n", - "[ 399 693 609 ... 1291 1270 1291] int16\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " if arr.dtype == np.object:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[ -750 -1254 -1107 ... 2276 1889 2067] int16\n", - "fbank\n", - "[ -127 -199 -149 ... -5243 -5065 -5398] int16\n", - "fbank\n", - "[ 465 783 677 ... 980 903 1008] int16\n", - "fbank\n", - "[ 90 160 157 ... -2 -16 -21] int16\n", - "fbank\n", - "[ 213 345 295 ... 2483 2246 2501] int16\n", - "fbank\n", - "[ -86 -159 -131 ... 270 258 290] int16\n", - "fbank\n", - "[-1023 -1714 -1505 ... 1532 1596 1575] int16\n", - "fbank\n", - "[-366 -602 -527 ... 374 370 379] int16\n", - "fbank\n", - "[ 761 1275 1127 ... 369 413 295] int16\n", - "fbank\n", - "[382 621 550 ... 161 161 174] int16\n", - "fbank\n", - "[ -28 -91 -120 ... 28 34 11] int16\n", - "fbank\n", - "[ -5 -5 -5 ... 268 294 341] int16\n", - "fbank\n", - "[240 417 684 ... 267 262 219] int16\n", - "fbank\n", - "[131 206 194 ... 383 320 343] int16\n", - "test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[31069, 21487, 29233, 30340, 20320, -1 , -1 ],\n", - " [20540, 24471, 19968, 25552, 30340, 26159, -1 ],\n", - " [36825, 20010, 31243, 24230, 26159, 32654, 30340],\n", - " [20108, 21040, 20108, -1 , -1 , -1 , -1 ],\n", - " [21435, 34892, 25919, 21270, -1 , -1 , -1 ]])\n", - "fbank\n", - "[1155 1890 1577 ... 1092 989 1130] int16\n", - "fbank\n", - "[296 358 296 ... 140 140 168] int16\n", - "fbank\n", - "[-50 -91 -63 ... 104 104 86] int16\n", - "fbank\n", - "[-37 -66 -50 ... -31 -45 -52] int16\n", - "fbank\n", - "[-401 -652 -547 ... -339 -307 -344] int16\n", - "fbank\n", - "[-21 -47 -51 ... 94 81 107] int16\n", - "fbank\n", - "[ 533 887 755 ... 3074 2853 3254] int16\n", - "fbank\n", - "[ 44 71 66 ... -628 -733 -601] int16\n", - "fbank\n", - "[ 50 86 79 ... 129 116 138] int16\n", - "fbank\n", - "[ 92 146 126 ... -208 -193 -179] int16\n", - "test raw: 祝可爱的你\n", - "test raw: 去行政化\n", - "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [184, 194, 196, 204, 207])\n", - "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [5, 6, 7, 3, 4])\n", - "audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[12.25633812, 12.61639309, 10.36936474, ..., 13.02949619, 11.51365757, 10.59789085],\n", - " [13.32148266, 13.41071606, 11.43800735, ..., 13.69783783, 12.83939362, 11.51259613],\n", - " [12.62640572, 12.53621101, 10.97212505, ..., 13.33757591, 12.32293034, 10.75493717],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[10.99619484, 11.35202599, 9.56922054 , ..., 9.94971657 , 9.88354111 , 9.55315971 ],\n", - " [10.44461155, 9.81688595 , 5.62538481 , ..., 10.60468388, 10.94417381, 9.42646980 ],\n", - " [10.23835754, 10.23407459, 7.99464273 , ..., 10.68097591, 9.91640091 , 10.04131031],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[14.10299397, 14.50298119, 12.87738323, ..., 12.62796497, 12.69949627, 11.43171215],\n", - " [13.85035992, 13.15289116, 10.66541386, ..., 13.34364223, 13.46972179, 11.02160740],\n", - " [13.19866467, 13.23537827, 11.65760899, ..., 12.72559357, 12.42716217, 11.74562359],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.85668373, 12.82431412, 11.68144703, ..., 14.10119247, 15.12791920, 13.68221378],\n", - " [13.19507027, 13.40244961, 11.43618393, ..., 13.32919979, 13.68267441, 12.73429012],\n", - " [13.02173328, 12.92082500, 11.44303989, ..., 12.77793121, 13.10915661, 11.77327728],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.90771198, 13.40234852, 13.01435471, ..., 13.80359459, 14.08088684, 13.17883396],\n", - " [14.06678009, 14.06943512, 12.52837276, ..., 13.66423225, 13.66300583, 13.60142994],\n", - " [12.58743191, 12.94520760, 11.75190544, ..., 14.28828907, 14.08229160, 13.02433395],\n", - " ...,\n", - " [16.20896912, 16.42283821, 14.94358730, ..., 12.91146755, 12.66766262, 11.76361752],\n", - " [13.49324894, 14.14653301, 13.16490936, ..., 13.23435783, 13.45378494, 12.60386276],\n", - " [15.56288910, 15.92445087, 14.90794277, ..., 13.43840790, 13.41075516, 12.55605984]]])\n" - ] - } - ], - "source": [ - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test:', text)\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('audio len:', audio_len)\n", - " print('test len:', text_len)\n", - " print('audio:', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "unexpected-skating", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "minus-modern", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "fbank\n", - "[232 387 331 ... 249 249 262] int16\n", - "fbank\n", - "[-138 -219 -192 ... 338 324 351] int16\n", - "fbank\n", - "[ 694 1175 1022 ... 553 514 627] int16\n", - "fbank\n", - "[-39 -79 -53 ... 139 172 99] int16\n", - "fbank\n", - "[-277 -480 -425 ... 758 767 739] int16\n", - "fbank\n", - "test: Tensor(shape=[5, 7], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[2695, 505, 2332, 2553, 169, -1 , -1 ],\n", - " [ 230, 1237, 2 , 1556, 2553, 1694, -1 ],\n", - " [3703, 28 , 2739, 1172, 1694, 2966, 2553],\n", - " [ 70 , 355, 70 , -1 , -1 , -1 , -1 ],\n", - " [ 477, 3363, 1621, 412, -1 , -1 , -1 ]])\n", - "[ 399 693 609 ... 1291 1270 1291] int16\n", - "test raw: ઇǹज৹©\n", - "test raw: ǝണٕƜ\n", - "test len: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [5, 6, 7, 3, 4])\n", - "audio: Tensor(shape=[5, 207, 80], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[12.25794601, 12.61855793, 10.37306023, ..., 13.12571049, 11.53678799, 10.32210350],\n", - " [13.32333183, 13.41336918, 11.44248962, ..., 13.65861225, 12.79308128, 11.31168747],\n", - " [12.62584686, 12.53506088, 10.96861362, ..., 13.32526493, 12.41560936, 10.71458912],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[11.00003052, 11.35529137, 9.56384087 , ..., 10.06063652, 10.16322994, 9.43149185 ],\n", - " [10.44556236, 9.81155300 , 5.49400425 , ..., 10.84116268, 11.02734756, 9.42253590 ],\n", - " [10.23620510, 10.23321152, 7.99466419 , ..., 10.93381882, 10.28395081, 10.00841141],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[14.10379314, 14.50375748, 12.87825108, ..., 12.68065739, 12.62359715, 11.53773308],\n", - " [13.84964657, 13.15079498, 10.67198086, ..., 13.24875164, 13.45796680, 10.97363472],\n", - " [13.19808197, 13.23482990, 11.65900230, ..., 12.70375061, 12.41395664, 11.88668156],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.85676289, 12.82410812, 11.67961884, ..., 14.12018299, 15.14850044, 13.80065727],\n", - " [13.19532776, 13.40243340, 11.43492508, ..., 13.29144669, 13.70278549, 12.67841339],\n", - " [13.02196407, 12.92111111, 11.43998623, ..., 12.71165752, 13.16518497, 11.92028046],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[12.90661621, 13.40162563, 13.01394463, ..., 13.84056377, 14.11240959, 13.21227264],\n", - " [14.06642914, 14.06922340, 12.52955723, ..., 13.55829811, 13.60157204, 13.50268650],\n", - " [12.58881378, 12.94780254, 11.75758171, ..., 14.29055786, 14.12165928, 13.02695847],\n", - " ...,\n", - " [16.20891571, 16.42290306, 14.94398117, ..., 12.86083794, 12.63515949, 11.67581463],\n", - " [13.49345875, 14.14656067, 13.16498375, ..., 13.28024578, 13.40956783, 12.70357513],\n", - " [15.56265163, 15.92387581, 14.90643024, ..., 13.45694065, 13.44703197, 12.81099033]]])\n", - "audio len: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [184, 194, 196, 204, 207])\n" - ] - } - ], - "source": [ - "keep_transcription_text=False\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " unit_type=args.unit_type,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " spm_model_prefix=args.spm_model_prefix,\n", - " augmentation_config='{}',\n", - " max_input_len=27.0,\n", - " min_input_len=0.0,\n", - " max_output_len=float('inf'),\n", - " min_output_len=0.0,\n", - " max_output_input_ratio=float('inf'),\n", - " min_output_input_ratio=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=keep_transcription_text,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " num_workers=0,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)\n", - "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test:', text)\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n", - " print(\"test raw:\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n", - " print('test len:', text_len)\n", - " print('audio:', audio)\n", - " print('audio len:', audio_len)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "competitive-mounting", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "knowing-military", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 1, 'specgram_type': 'fbank', 'feat_dim': 80, 'delta_delta': False, 'stride_ms': 10.0, 'window_ms': 25.0, 'sample_rate': 16000, 'manifest_path': 'examples/aishell/s1/data/manifest.train', 'output_path': 'examples/aishell/s1/data/mean_std.npz'}\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "\n", - "add_arg('num_samples', int, 1, \"# of samples to for statistics.\")\n", - "add_arg('specgram_type', str, 'fbank',\n", - " \"Audio feature type. Options: linear, mfcc, fbank.\",\n", - " choices=['linear', 'mfcc', 'fbank'])\n", - "add_arg('feat_dim', int, 80, \"Audio feature dim.\")\n", - "add_arg('delta_delta', bool, False,\"Audio feature with delta delta.\")\n", - "add_arg('stride_ms', float, 10.0, \"stride length in ms.\")\n", - "add_arg('window_ms', float, 25.0, \"stride length in ms.\")\n", - "add_arg('sample_rate', int, 16000, \"target sample rate.\")\n", - "add_arg('manifest_path', str,\n", - " 'examples/aishell/s1/data/manifest.train',\n", - " \"Filepath of manifest to compute normalizer's mean and stddev.\")\n", - "add_arg('output_path', str,\n", - " 'examples/aishell/s1/data/mean_std.npz',\n", - " \"Filepath of write mean and stddev to (.npz).\")\n", - "args = parser.parse_args([])\n", - "print(vars(args))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "unnecessary-province", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline\n", - "from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer\n", - "from deepspeech.frontend.normalizer import FeatureNormalizer\n", - "from deepspeech.frontend.audio import AudioSegment\n", - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.frontend.utility import read_manifest\n", - "\n", - "\n", - "\n", - "def mean(args):\n", - " augmentation_pipeline = AugmentationPipeline('{}')\n", - " audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=True,\n", - " target_dB=-20,\n", - " dither=0.0)\n", - "\n", - " def augment_and_featurize(audio_segment):\n", - " augmentation_pipeline.transform_audio(audio_segment)\n", - " return audio_featurizer.featurize(audio_segment)\n", - "\n", - " normalizer = FeatureNormalizer(\n", - " mean_std_filepath=None,\n", - " manifest_path=args.manifest_path,\n", - " featurize_func=augment_and_featurize,\n", - " num_samples=args.num_samples)\n", - " normalizer.write_to_file(args.output_path)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "interested-camping", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "[54. 90. 77. ... 58. 58. 61.]\n", - "29746\n", - "fbank\n", - "[54 90 77 ... 58 58 61] int16\n", - "(184, 80) float64\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n", - "(184, 80) float64\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n" - ] - } - ], - "source": [ - "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/test/S0916/BAC009S0916W0426.wav'\n", - "test='祝可爱的你'\n", - "audio_featurizer = AudioFeaturizer(\n", - " specgram_type=args.specgram_type,\n", - " feat_dim=args.feat_dim,\n", - " delta_delta=args.delta_delta,\n", - " stride_ms=args.stride_ms,\n", - " window_ms=args.window_ms,\n", - " n_fft=None,\n", - " max_freq=None,\n", - " target_sample_rate=args.sample_rate,\n", - " use_dB_normalization=False,\n", - " target_dB=-20,\n", - " dither=0.0)\n", - "samples = AudioSegment.from_file(wav)\n", - "print(samples._samples)\n", - "print(samples._samples * 2**15)\n", - "print(len(samples._samples))\n", - "feat = audio_featurizer.featurize(samples, False, False)\n", - "feat = feat.T\n", - "print(feat.shape, feat.dtype)\n", - "print(feat)\n", - "\n", - "from python_speech_features import logfbank\n", - "max_freq = args.sample_rate / 2\n", - "fbank_feat = logfbank(\n", - " signal=samples.to('int16'),\n", - " samplerate=args.sample_rate,\n", - " winlen=0.001 * args.window_ms,\n", - " winstep=0.001 * args.stride_ms,\n", - " nfilt=args.feat_dim,\n", - " nfft=512,\n", - " lowfreq=20,\n", - " highfreq=max_freq,\n", - " preemph=0.97,\n", - " dither=0.0,\n", - " wintype='povey')\n", - "print(fbank_feat.shape, fbank_feat.dtype)\n", - "print(fbank_feat)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "numeric-analyst", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(184, 160)\n", - "[ 8.59522397 8.43148278 8.36414052 8.45487173 8.31761643 8.04843683\n", - " 8.01683696 7.6574614 7.95521932 8.22945157 10.20138275 9.0447775\n", - " 9.14763398 9.18184349 9.03801065 9.04852307 8.67706728 8.71894271\n", - " 9.54553655 9.19535135 8.76413076 8.47828946 8.52586143 8.49469288\n", - " 8.72461247 8.28562879 8.11581393 7.99922156 7.91023364 8.04142296\n", - " 7.89762773 7.76257636 8.32043745 8.01592886 8.34109665 8.90115454\n", - " 8.48246945 7.98658664 8.05745122 8.11384088 8.18864479 8.8091827\n", - " 11.8067711 13.25258218 14.44311795 13.90515283 14.00120623 13.99801252\n", - " 13.81595394 13.6379904 13.3574897 13.14933334 12.96518543 13.02601156\n", - " 12.70246737 12.54410834 12.15615068 11.86574681 11.67497882 10.79645481\n", - " 10.48150035 10.03758575 10.05637027 9.92891308 10.06923218 12.43382431\n", - " 12.71428321 14.33135052 13.94470959 14.29188291 14.11483993 14.03496606\n", - " 13.78167331 13.66701466 14.40308625 14.73934137 15.09569382 14.89565815\n", - " 15.10519995 14.94383582 15.03275563 15.42194679 15.29219967 15.41602274\n", - " 15.39242545 15.76836177 16.259222 16.47777231 17.03366795 17.46165793\n", - " 17.52596217 17.78844031 17.99878075 18.11446843 17.95761578 17.99900337\n", - " 17.86282737 17.7290163 17.47686504 17.43425516 17.07750485 16.64395242\n", - " 15.68217043 14.90058399 14.45645737 14.0405463 14.89549542 16.00405781\n", - " 16.27301689 16.37572895 16.31219037 16.31765447 16.44819716 16.36281089\n", - " 16.24932823 15.79302555 14.76361963 13.95761882 13.48917053 13.45543501\n", - " 13.00091327 13.13854248 13.74596395 13.86340629 14.00656109 13.77432101\n", - " 13.64267001 13.35742634 13.23042234 12.97916104 12.80694468 12.70005006\n", - " 13.2802483 13.22644525 13.14579624 13.02536594 13.36511022 11.37167205\n", - " 12.11598045 12.47619798 12.83885973 11.63880287 11.42083924 11.08747705\n", - " 11.04093403 11.11263149 10.74353319 10.58734669 10.46180738 10.34157335\n", - " 9.63131146 9.70582692 9.29059204 8.94583657 8.66065094 8.46799095\n", - " 8.25064103 8.30239167 8.19463371 8.12104567 8.02731234 8.06412715\n", - " 7.84889951 7.73090283 7.74119562 7.85444657 7.80717312 7.7129933\n", - " 7.84087442 7.77907788 7.60660865 7.55051479 7.458385 7.496416\n", - " 7.69519793 7.49086759 7.32199493 8.01617458 7.58525375 7.06661122\n", - " 6.94653756 7.19874283 7.28515661 7.17574078]\n", - "(184,)\n", - "(184,)\n", - "[1.48370471 1.52174523 1.46984238 1.67010478 1.88757689 1.68825992\n", - " 1.74270259 1.55497318 1.29200818 1.68446481 1.88133219 1.97138928\n", - " 2.15910096 2.3149476 1.9820247 2.07694378 1.93498835 2.01493974\n", - " 2.39156824 2.02396518 1.69586449 1.63808752 1.64020228 1.43573473\n", - " 1.93092656 1.37466294 1.34704929 1.59600739 1.03960441 1.45276496\n", - " 1.59360131 1.57466343 1.89491479 1.79333746 1.32701974 1.49441767\n", - " 1.51466756 1.63497989 1.42858074 1.51135396 1.61077201 1.81066387\n", - " 1.83367783 2.3507094 2.87885378 3.26231227 2.1313117 1.98557548\n", - " 1.99105426 2.26150533 2.34298751 2.44621608 2.39201042 2.41226503\n", - " 2.5142992 3.03777565 2.81592295 2.75117863 2.78324175 2.68819666\n", - " 2.8945782 2.84464168 2.680973 2.78397395 2.47996808 1.71829563\n", - " 1.60636949 1.65992483 1.38122631 1.74831825 2.16006884 1.68076185\n", - " 1.69329487 1.44929837 1.63763312 1.80101076 2.01166253 2.03254244\n", - " 1.9583913 2.04542255 2.00859694 2.16600883 2.16095629 1.97541122\n", - " 2.13807632 2.06386436 2.2154187 2.84205688 2.54862449 2.64321545\n", - " 2.6805773 2.52300146 2.53209001 2.54682059 2.4521937 2.43155532\n", - " 2.42571275 2.23421289 2.23164529 2.23597192 2.14215121 2.10406703\n", - " 2.07962874 1.88506161 1.80092372 1.61156092 1.77426835 1.98765563\n", - " 2.0356793 1.87964187 1.779513 1.87187681 1.76463632 1.70978684\n", - " 1.76471778 1.75604749 1.62792552 1.73929352 1.6887024 1.8677704\n", - " 2.17342368 2.08166072 2.14567453 2.15936953 2.18351006 2.41010388\n", - " 2.26101752 2.25468001 2.23739715 2.15395133 2.04547813 1.92038843\n", - " 1.85491264 1.91905927 2.16709365 1.99924152 2.1850471 2.55461622\n", - " 2.72476673 1.69682926 1.73249614 2.06992695 2.1210591 1.66854454\n", - " 1.63907505 1.32203822 1.38992558 1.2436937 1.17932877 1.02963653\n", - " 1.26085036 1.16997132 1.09339504 1.14188689 1.18675772 1.31859788\n", - " 1.21746591 1.3872131 1.26095274 1.34885761 1.46633543 1.64506975\n", - " 1.36013821 1.45574721 1.43766588 1.65119054 1.57163772 1.55082968\n", - " 1.29413316 1.38351736 1.64234673 1.57186432 1.45381083 1.71204761\n", - " 1.51828607 1.30639985 1.32928395 1.49004237 1.6057589 1.81815735\n", - " 1.67784678 1.72180861 1.60703743 1.64850255]\n" - ] - } - ], - "source": [ - "a = np.hstack([feat, feat])\n", - "print(a.shape)\n", - "m = np.mean(a, axis=1)\n", - "print(m)\n", - "print(m.shape)\n", - "std = np.std(a, axis=1)\n", - "print(std.shape)\n", - "print(std)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "nonprofit-potato", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "hispanic-ethics", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torchaudio\n", - "import torchaudio.compliance.kaldi as kaldi\n", - "import torchaudio.sox_effects as sox_effects\n", - "from torch.nn.utils.rnn import pad_sequence\n", - "torchaudio.set_audio_backend(\"sox\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "changing-calvin", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 29746])\n", - "tensor([[54., 90., 77., ..., 58., 58., 61.]])\n", - "(184, 80)\n", - "[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n", - " [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n", - " [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n", - " ...\n", - " [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n", - " [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n", - " [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n", - "-----------\n", - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "(184, 80)\n", - "[[-10.177039 -10.717326 -15.46954 ... -10.546229 -11.897424 -12.987689]\n", - " [ -9.750411 -10.476343 -14.485752 ... -9.557108 -10.436023 -11.955799]\n", - " [-10.525113 -10.798049 -13.46475 ... -10.343097 -11.101464 -12.832712]\n", - " ...\n", - " [-10.649446 -10.907673 -14.056403 ... -10.578607 -11.790988 -12.038239]\n", - " [-10.816959 -11.114918 -12.88781 ... -10.570049 -11.199847 -13.101528]\n", - " [-14.320845 -13.03106 -13.036756 ... -10.829194 -11.171779 -12.634331]]\n", - "**************\n", - "[0.00164795 0.00274658 0.00234985 ... 0.00177002 0.00177002 0.00186157]\n", - "[54. 90. 77. ... 58. 58. 61.] float32\n", - "(184, 80)\n", - "[[10.617376 10.077089 5.3248763 ... 10.248186 8.896992 7.8067265]\n", - " [11.044004 10.318072 6.3086634 ... 11.237308 10.358393 8.838616 ]\n", - " [10.269302 9.9963665 7.3296647 ... 10.451319 9.692951 7.9617033]\n", - " ...\n", - " [10.14497 9.886743 6.738012 ... 10.215809 9.0034275 8.756177 ]\n", - " [ 9.977456 9.679498 7.9066052 ... 10.224365 9.594568 7.6928873]\n", - " [ 6.4735703 7.7633557 7.7576594 ... 9.965221 9.622637 8.160085 ]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: UserWarning: torchaudio.backend.sox_backend.load_wav has been deprecated and will be removed from 0.9.0 release. Please use \"torchaudio.load\".\n", - " \"\"\"Entry point for launching an IPython kernel.\n" - ] - } - ], - "source": [ - "waveform, sample_rate = torchaudio.load_wav(wav)\n", - "print(waveform.shape)\n", - "print(waveform)\n", - "mat = kaldi.fbank(\n", - " waveform,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)\n", - "\n", - "print('-----------')\n", - "print(samples._samples)\n", - "aud = torch.tensor(samples._samples).view(1, -1)\n", - "mat = kaldi.fbank(\n", - " aud,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)\n", - "\n", - "print('**************')\n", - "print(samples._samples)\n", - "tmp = samples.to('int16').astype('float32')\n", - "print(tmp, tmp.dtype)\n", - "aud = torch.tensor(tmp).view(1, -1)\n", - "mat = kaldi.fbank(\n", - " aud,\n", - " num_mel_bins=80,\n", - " frame_length=25,\n", - " frame_shift=10,\n", - " dither=0,\n", - " energy_floor=0.0,\n", - " sample_frequency=sample_rate\n", - " )\n", - "mat = mat.detach().numpy()\n", - "print(mat.shape)\n", - "print(mat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "buried-dependence", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "silver-printing", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "outer-space", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(29746,)\n", - "[54 90 77 ... 58 58 61]\n", - "(184, 80)\n", - "[[10.61737914 10.07708936 5.32487528 ... 10.2481839 8.89699394\n", - " 7.80671114]\n", - " [11.0440077 10.3180721 6.30866128 ... 11.23730926 10.35838868\n", - " 8.83860079]\n", - " [10.26930555 9.99636567 7.3296638 ... 10.45131595 9.69295303\n", - " 7.96168491]\n", - " ...\n", - " [10.14497345 9.88674207 6.73801138 ... 10.21580627 9.00343472\n", - " 8.75616521]\n", - " [ 9.97745961 9.67949736 7.90660425 ... 10.22436653 9.59456493\n", - " 7.69287184]\n", - " [ 6.47357374 7.76335491 7.75765843 ... 9.96522077 9.6226365\n", - " 8.16007108]]\n", - "(184, 13)\n", - "[[ 14.73775998 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n", - " 8.86862748]\n", - " [ 15.31274834 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n", - " 6.78353115]\n", - " [ 13.82218765 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n", - " -0.05919222]\n", - " ...\n", - " [ 13.5837844 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n", - " 5.59045842]\n", - " [ 13.75757034 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n", - " 12.0785146 ]\n", - " [ 11.92813809 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n", - " 0.88337965]]\n" - ] - } - ], - "source": [ - "from python_speech_features import mfcc\n", - "from python_speech_features import delta\n", - "from python_speech_features import logfbank\n", - "import scipy.io.wavfile as iowav\n", - "\n", - "(rate,sig) = iowav.read(wav)\n", - "print(sig.shape)\n", - "print(sig)\n", - "\n", - "# note that generally nfilt=40 is used for speech recognition\n", - "fbank_feat = logfbank(sig,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n", - "print(fbank_feat.shape)\n", - "print(fbank_feat)\n", - "\n", - "# the computed fbank coefficents of english.wav with dimension [110,23]\n", - "# [ 12.2865\t12.6906\t13.1765\t15.714\t16.064\t15.7553\t16.5746\t16.9205\t16.6472\t16.1302\t16.4576\t16.7326\t16.8864\t17.7215\t18.88\t19.1377\t19.1495\t18.6683\t18.3886\t20.3506\t20.2772\t18.8248\t18.1899\n", - "# 11.9198\t13.146\t14.7215\t15.8642\t17.4288\t16.394\t16.8238\t16.1095\t16.4297\t16.6331\t16.3163\t16.5093\t17.4981\t18.3429\t19.6555\t19.6263\t19.8435\t19.0534\t19.001\t20.0287\t19.7707\t19.5852\t19.1112\n", - "# ...\n", - "# ...\n", - "# the same with that using kaldi commands: compute-fbank-feats --dither=0.0\n", - "\n", - "mfcc_feat = mfcc(sig,dither=0,useEnergy=True,wintype='povey')\n", - "print(mfcc_feat.shape)\n", - "print(mfcc_feat)\n", - "\n", - "# the computed mfcc coefficents of english.wav with dimension [110,13]\n", - "# [ 17.1337\t-23.3651\t-7.41751\t-7.73686\t-21.3682\t-8.93884\t-3.70843\t4.68346\t-16.0676\t12.782\t-7.24054\t8.25089\t10.7292\n", - "# 17.1692\t-23.3028\t-5.61872\t-4.0075\t-23.287\t-20.6101\t-5.51584\t-6.15273\t-14.4333\t8.13052\t-0.0345329\t2.06274\t-0.564298\n", - "# ...\n", - "# ...\n", - "# the same with that using kaldi commands: compute-mfcc-feats --dither=0.0" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "sporting-school", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(184, 80)\n", - "[[-10.17703627 -10.71732606 -15.46954014 ... -10.54623152 -11.89742148\n", - " -12.98770428]\n", - " [ -9.75040771 -10.47634331 -14.48575413 ... -9.55710616 -10.43602673\n", - " -11.95581463]\n", - " [-10.52510987 -10.79804975 -13.46475161 ... -10.34309947 -11.10146239\n", - " -12.83273051]\n", - " ...\n", - " [-10.64944197 -10.90767335 -14.05640404 ... -10.57860915 -11.7909807\n", - " -12.03825021]\n", - " [-10.8169558 -11.11491806 -12.88781116 ... -10.57004889 -11.19985048\n", - " -13.10154358]\n", - " [-14.32084168 -13.03106051 -13.03675699 ... -10.82919465 -11.17177892\n", - " -12.63434434]]\n", - "(184, 13)\n", - "[[ -6.05665544 -13.30393391 5.85974818 ... -3.42359739 2.82785335\n", - " 8.86862748]\n", - " [ -5.48166707 -13.33671651 4.06537223 ... 8.15970347 2.15934846\n", - " 6.78353115]\n", - " [ -6.97222776 -13.39296404 6.8304843 ... 2.55332563 8.86724453\n", - " -0.05919222]\n", - " ...\n", - " [ -7.21063102 -13.42104892 11.21222354 ... 4.81477718 1.66627505\n", - " 5.59045842]\n", - " [ -7.03684508 -13.92626662 13.06074011 ... -0.46694046 5.56214833\n", - " 12.0785146 ]\n", - " [ -8.86627732 -15.9169855 8.78372271 ... -1.42014277 -3.25768086\n", - " 0.88337965]]\n" - ] - } - ], - "source": [ - "fbank_feat = logfbank(samples._samples,nfilt=80,lowfreq=20,dither=0,wintype='povey')\n", - "print(fbank_feat.shape)\n", - "print(fbank_feat)\n", - "\n", - "mfcc_feat = mfcc(samples._samples,dither=0,useEnergy=True,wintype='povey')\n", - "print(mfcc_feat.shape)\n", - "print(mfcc_feat)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "restricted-license", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "specialized-threat", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/espnet_dataloader.ipynb b/.notebook/espnet_dataloader.ipynb deleted file mode 100644 index 1bfc13e3c..000000000 --- a/.notebook/espnet_dataloader.ipynb +++ /dev/null @@ -1,1541 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 147, - "id": "extensive-venice", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/\n" - ] - }, - { - "data": { - "text/plain": [ - "'/'" - ] - }, - "execution_count": 147, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 148, - "id": "correct-window", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "manifest.dev\t manifest.test-clean\t manifest.train\r\n", - "manifest.dev.raw manifest.test-clean.raw manifest.train.raw\r\n" - ] - } - ], - "source": [ - "!ls /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/" - ] - }, - { - "cell_type": "code", - "execution_count": 149, - "id": "exceptional-cheese", - "metadata": {}, - "outputs": [], - "source": [ - "dev_data='/workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev'" - ] - }, - { - "cell_type": "code", - "execution_count": 150, - "id": "extraordinary-orleans", - "metadata": {}, - "outputs": [], - "source": [ - "from deepspeech.frontend.utility import read_manifest" - ] - }, - { - "cell_type": "code", - "execution_count": 151, - "id": "returning-lighter", - "metadata": {}, - "outputs": [], - "source": [ - "dev_json = read_manifest(dev_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 152, - "id": "western-founder", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev/deltafalse/feats.1.ark:16',\n", - " 'name': 'input1',\n", - " 'shape': [1063, 83]}],\n", - " 'output': [{'name': 'target1',\n", - " 'shape': [41, 5002],\n", - " 'text': 'AS I APPROACHED THE CITY I HEARD BELLS RINGING AND A '\n", - " 'LITTLE LATER I FOUND THE STREETS ASTIR WITH THRONGS OF '\n", - " 'WELL DRESSED PEOPLE IN FAMILY GROUPS WENDING THEIR WAY '\n", - " 'HITHER AND THITHER',\n", - " 'token': '▁AS ▁I ▁APPROACHED ▁THE ▁CITY ▁I ▁HEARD ▁BELL S ▁RING '\n", - " 'ING ▁AND ▁A ▁LITTLE ▁LATER ▁I ▁FOUND ▁THE ▁STREETS ▁AS '\n", - " 'T IR ▁WITH ▁THRONG S ▁OF ▁WELL ▁DRESSED ▁PEOPLE ▁IN '\n", - " '▁FAMILY ▁GROUP S ▁WE ND ING ▁THEIR ▁WAY ▁HITHER ▁AND '\n", - " '▁THITHER',\n", - " 'tokenid': '713 2458 676 4502 1155 2458 2351 849 389 3831 206 627 '\n", - " '482 2812 2728 2458 2104 4502 4316 713 404 212 4925 '\n", - " '4549 389 3204 4861 1677 3339 2495 1950 2279 389 4845 '\n", - " '302 206 4504 4843 2394 627 4526'}],\n", - " 'utt': '116-288045-0000',\n", - " 'utt2spk': '116-288045'}\n", - "5542\n", - "\n" - ] - } - ], - "source": [ - "from pprint import pprint\n", - "pprint(dev_json[0])\n", - "print(len(dev_json))\n", - "print(type(dev_json))" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "id": "motivated-receptor", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "import itertools\n", - "\n", - "import numpy as np\n", - "\n", - "from deepspeech.utils.log import Log\n", - "\n", - "__all__ = [\"make_batchset\"]\n", - "\n", - "logger = Log(__name__).getlog()\n", - "\n", - "\n", - "def batchfy_by_seq(\n", - " sorted_data,\n", - " batch_size,\n", - " max_length_in,\n", - " max_length_out,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " iaxis=0,\n", - " okey=\"output\",\n", - " oaxis=0, ):\n", - " \"\"\"Make batch set from json dictionary\n", - "\n", - " :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n", - " :param int batch_size: batch size\n", - " :param int max_length_in: maximum length of input to decide adaptive batch size\n", - " :param int max_length_out: maximum length of output to decide adaptive batch size\n", - " :param int min_batch_size: mininum batch size (for multi-gpu)\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - " :param str ikey: key to access input\n", - " (for ASR ikey=\"input\", for TTS, MT ikey=\"output\".)\n", - " :param int iaxis: dimension to access input\n", - " (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n", - " :param str okey: key to access output\n", - " (for ASR, MT okey=\"output\". for TTS okey=\"input\".)\n", - " :param int oaxis: dimension to access output\n", - " (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)\n", - " :return: List[List[Tuple[str, dict]]] list of batches\n", - " \"\"\"\n", - " if batch_size <= 0:\n", - " raise ValueError(f\"Invalid batch_size={batch_size}\")\n", - "\n", - " # check #utts is more than min_batch_size\n", - " if len(sorted_data) < min_batch_size:\n", - " raise ValueError(\n", - " f\"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).\"\n", - " )\n", - "\n", - " # make list of minibatches\n", - " minibatches = []\n", - " start = 0\n", - " while True:\n", - " _, info = sorted_data[start]\n", - " ilen = int(info[ikey][iaxis][\"shape\"][0])\n", - " olen = (int(info[okey][oaxis][\"shape\"][0]) if oaxis >= 0 else\n", - " max(map(lambda x: int(x[\"shape\"][0]), info[okey])))\n", - " factor = max(int(ilen / max_length_in), int(olen / max_length_out))\n", - " # change batchsize depending on the input and output length\n", - " # if ilen = 1000 and max_length_in = 800\n", - " # then b = batchsize / 2\n", - " # and max(min_batches, .) avoids batchsize = 0\n", - " bs = max(min_batch_size, int(batch_size / (1 + factor)))\n", - " end = min(len(sorted_data), start + bs)\n", - " minibatch = sorted_data[start:end]\n", - " if shortest_first:\n", - " minibatch.reverse()\n", - "\n", - " # check each batch is more than minimum batchsize\n", - " if len(minibatch) < min_batch_size:\n", - " mod = min_batch_size - len(minibatch) % min_batch_size\n", - " additional_minibatch = [\n", - " sorted_data[i] for i in np.random.randint(0, start, mod)\n", - " ]\n", - " if shortest_first:\n", - " additional_minibatch.reverse()\n", - " minibatch.extend(additional_minibatch)\n", - " minibatches.append(minibatch)\n", - "\n", - " if end == len(sorted_data):\n", - " break\n", - " start = end\n", - "\n", - " # batch: List[List[Tuple[str, dict]]]\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_by_bin(\n", - " sorted_data,\n", - " batch_bins,\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " okey=\"output\", ):\n", - " \"\"\"Make variably sized batch set, which maximizes\n", - "\n", - " the number of bins up to `batch_bins`.\n", - "\n", - " :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json\n", - " :param int batch_bins: Maximum frames of a batch\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param int test: Return only every `test` batches\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - "\n", - " :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n", - " :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n", - "\n", - " :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n", - " \"\"\"\n", - " if batch_bins <= 0:\n", - " raise ValueError(f\"invalid batch_bins={batch_bins}\")\n", - " length = len(sorted_data)\n", - " idim = int(sorted_data[0][1][ikey][0][\"shape\"][1])\n", - " odim = int(sorted_data[0][1][okey][0][\"shape\"][1])\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " minibatches = []\n", - " start = 0\n", - " n = 0\n", - " while True:\n", - " # Dynamic batch size depending on size of samples\n", - " b = 0\n", - " next_size = 0\n", - " max_olen = 0\n", - " while next_size < batch_bins and (start + b) < length:\n", - " ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0]) * idim\n", - " olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0]) * odim\n", - " if olen > max_olen:\n", - " max_olen = olen\n", - " next_size = (max_olen + ilen) * (b + 1)\n", - " if next_size <= batch_bins:\n", - " b += 1\n", - " elif next_size == 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in batch_bins ({batch_bins}): \"\n", - " f\"Please increase the value\")\n", - " end = min(length, start + max(min_batch_size, b))\n", - " batch = sorted_data[start:end]\n", - " if shortest_first:\n", - " batch.reverse()\n", - " minibatches.append(batch)\n", - " # Check for min_batch_size and fixes the batches if needed\n", - " i = -1\n", - " while len(minibatches[i]) < min_batch_size:\n", - " missing = min_batch_size - len(minibatches[i])\n", - " if -i == len(minibatches):\n", - " minibatches[i + 1].extend(minibatches[i])\n", - " minibatches = minibatches[1:]\n", - " break\n", - " else:\n", - " minibatches[i].extend(minibatches[i - 1][:missing])\n", - " minibatches[i - 1] = minibatches[i - 1][missing:]\n", - " i -= 1\n", - " if end == length:\n", - " break\n", - " start = end\n", - " n += 1\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " lengths = [len(x) for x in minibatches]\n", - " logger.info(\n", - " str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n", - " + \" to \" + str(max(lengths)) + \" samples \" + \"(avg \" + str(\n", - " int(np.mean(lengths))) + \" samples).\")\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_by_frame(\n", - " sorted_data,\n", - " max_frames_in,\n", - " max_frames_out,\n", - " max_frames_inout,\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " ikey=\"input\",\n", - " okey=\"output\", ):\n", - " \"\"\"Make variable batch set, which maximizes the number of frames to max_batch_frame.\n", - "\n", - " :param List[(str, Dict[str, Any])] sorteddata: dictionary loaded from data.json\n", - " :param int max_frames_in: Maximum input frames of a batch\n", - " :param int max_frames_out: Maximum output frames of a batch\n", - " :param int max_frames_inout: Maximum input+output frames of a batch\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param int test: Return only every `test` batches\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - "\n", - " :param str ikey: key to access input (for ASR ikey=\"input\", for TTS ikey=\"output\".)\n", - " :param str okey: key to access output (for ASR okey=\"output\". for TTS okey=\"input\".)\n", - "\n", - " :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches\n", - " \"\"\"\n", - " if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:\n", - " raise ValueError(\n", - " \"At least, one of `--batch-frames-in`, `--batch-frames-out` or \"\n", - " \"`--batch-frames-inout` should be > 0\")\n", - " length = len(sorted_data)\n", - " minibatches = []\n", - " start = 0\n", - " end = 0\n", - " while end != length:\n", - " # Dynamic batch size depending on size of samples\n", - " b = 0\n", - " max_olen = 0\n", - " max_ilen = 0\n", - " while (start + b) < length:\n", - " ilen = int(sorted_data[start + b][1][ikey][0][\"shape\"][0])\n", - " if ilen > max_frames_in and max_frames_in != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-in ({max_frames_in}): \"\n", - " f\"Please increase the value\")\n", - " olen = int(sorted_data[start + b][1][okey][0][\"shape\"][0])\n", - " if olen > max_frames_out and max_frames_out != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-out ({max_frames_out}): \"\n", - " f\"Please increase the value\")\n", - " if ilen + olen > max_frames_inout and max_frames_inout != 0:\n", - " raise ValueError(\n", - " f\"Can't fit one sample in --batch-frames-out ({max_frames_inout}): \"\n", - " f\"Please increase the value\")\n", - " max_olen = max(max_olen, olen)\n", - " max_ilen = max(max_ilen, ilen)\n", - " in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0\n", - " out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0\n", - " inout_ok = (max_ilen + max_olen) * (\n", - " b + 1) <= max_frames_inout or max_frames_inout == 0\n", - " if in_ok and out_ok and inout_ok:\n", - " # add more seq in the minibatch\n", - " b += 1\n", - " else:\n", - " # no more seq in the minibatch\n", - " break\n", - " end = min(length, start + b)\n", - " batch = sorted_data[start:end]\n", - " if shortest_first:\n", - " batch.reverse()\n", - " minibatches.append(batch)\n", - " # Check for min_batch_size and fixes the batches if needed\n", - " i = -1\n", - " while len(minibatches[i]) < min_batch_size:\n", - " missing = min_batch_size - len(minibatches[i])\n", - " if -i == len(minibatches):\n", - " minibatches[i + 1].extend(minibatches[i])\n", - " minibatches = minibatches[1:]\n", - " break\n", - " else:\n", - " minibatches[i].extend(minibatches[i - 1][:missing])\n", - " minibatches[i - 1] = minibatches[i - 1][missing:]\n", - " i -= 1\n", - " start = end\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " lengths = [len(x) for x in minibatches]\n", - " logger.info(\n", - " str(len(minibatches)) + \" batches containing from \" + str(min(lengths))\n", - " + \" to \" + str(max(lengths)) + \" samples\" + \"(avg \" + str(\n", - " int(np.mean(lengths))) + \" samples).\")\n", - "\n", - " return minibatches\n", - "\n", - "\n", - "def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,\n", - " shortest_first):\n", - " import random\n", - "\n", - " logger.info(\"use shuffled batch.\")\n", - " sorted_data = random.sample(data.items(), len(data.items()))\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " # make list of minibatches\n", - " minibatches = []\n", - " start = 0\n", - " while True:\n", - " end = min(len(sorted_data), start + batch_size)\n", - " # check each batch is more than minimum batchsize\n", - " minibatch = sorted_data[start:end]\n", - " if shortest_first:\n", - " minibatch.reverse()\n", - " if len(minibatch) < min_batch_size:\n", - " mod = min_batch_size - len(minibatch) % min_batch_size\n", - " additional_minibatch = [\n", - " sorted_data[i] for i in np.random.randint(0, start, mod)\n", - " ]\n", - " if shortest_first:\n", - " additional_minibatch.reverse()\n", - " minibatch.extend(additional_minibatch)\n", - " minibatches.append(minibatch)\n", - " if end == len(sorted_data):\n", - " break\n", - " start = end\n", - "\n", - " # for debugging\n", - " if num_batches > 0:\n", - " minibatches = minibatches[:num_batches]\n", - " logger.info(\"# minibatches: \" + str(len(minibatches)))\n", - " return minibatches\n", - "\n", - "\n", - "BATCH_COUNT_CHOICES = [\"auto\", \"seq\", \"bin\", \"frame\"]\n", - "BATCH_SORT_KEY_CHOICES = [\"input\", \"output\", \"shuffle\"]\n", - "\n", - "\n", - "def make_batchset(\n", - " data,\n", - " batch_size=0,\n", - " max_length_in=float(\"inf\"),\n", - " max_length_out=float(\"inf\"),\n", - " num_batches=0,\n", - " min_batch_size=1,\n", - " shortest_first=False,\n", - " batch_sort_key=\"input\",\n", - " count=\"auto\",\n", - " batch_bins=0,\n", - " batch_frames_in=0,\n", - " batch_frames_out=0,\n", - " batch_frames_inout=0,\n", - " iaxis=0,\n", - " oaxis=0, ):\n", - " \"\"\"Make batch set from json dictionary\n", - "\n", - " if utts have \"category\" value,\n", - "\n", - " >>> data = {'utt1': {'category': 'A', 'input': ...},\n", - " ... 'utt2': {'category': 'B', 'input': ...},\n", - " ... 'utt3': {'category': 'B', 'input': ...},\n", - " ... 'utt4': {'category': 'A', 'input': ...}}\n", - " >>> make_batchset(data, batchsize=2, ...)\n", - " [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]\n", - "\n", - " Note that if any utts doesn't have \"category\",\n", - " perform as same as batchfy_by_{count}\n", - "\n", - " :param List[Dict[str, Any]] data: dictionary loaded from data.json\n", - " :param int batch_size: maximum number of sequences in a minibatch.\n", - " :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.\n", - " :param int batch_frames_in: maximum number of input frames in a minibatch.\n", - " :param int batch_frames_out: maximum number of output frames in a minibatch.\n", - " :param int batch_frames_out: maximum number of input+output frames in a minibatch.\n", - " :param str count: strategy to count maximum size of batch.\n", - " For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES\n", - "\n", - " :param int max_length_in: maximum length of input to decide adaptive batch size\n", - " :param int max_length_out: maximum length of output to decide adaptive batch size\n", - " :param int num_batches: # number of batches to use (for debug)\n", - " :param int min_batch_size: minimum batch size (for multi-gpu)\n", - " :param bool shortest_first: Sort from batch with shortest samples\n", - " to longest if true, otherwise reverse\n", - " :param str batch_sort_key: how to sort data before creating minibatches\n", - " [\"input\", \"output\", \"shuffle\"]\n", - " :param bool swap_io: if True, use \"input\" as output and \"output\"\n", - " as input in `data` dict\n", - " :param bool mt: if True, use 0-axis of \"output\" as output and 1-axis of \"output\"\n", - " as input in `data` dict\n", - " :param int iaxis: dimension to access input\n", - " (for ASR, TTS iaxis=0, for MT iaxis=\"1\".)\n", - " :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,\n", - " reserved for future research, -1 means all axis.)\n", - " :return: List[List[Tuple[str, dict]]] list of batches\n", - " \"\"\"\n", - "\n", - " # check args\n", - " if count not in BATCH_COUNT_CHOICES:\n", - " raise ValueError(\n", - " f\"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}\")\n", - " if batch_sort_key not in BATCH_SORT_KEY_CHOICES:\n", - " raise ValueError(f\"arg 'batch_sort_key' ({batch_sort_key}) should be \"\n", - " f\"one of {BATCH_SORT_KEY_CHOICES}\")\n", - "\n", - " ikey = \"input\"\n", - " okey = \"output\"\n", - " batch_sort_axis = 0 # index of list \n", - "\n", - " if count == \"auto\":\n", - " if batch_size != 0:\n", - " count = \"seq\"\n", - " elif batch_bins != 0:\n", - " count = \"bin\"\n", - " elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:\n", - " count = \"frame\"\n", - " else:\n", - " raise ValueError(\n", - " f\"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}\"\n", - " )\n", - " logger.info(f\"count is auto detected as {count}\")\n", - "\n", - " if count != \"seq\" and batch_sort_key == \"shuffle\":\n", - " raise ValueError(\n", - " \"batch_sort_key=shuffle is only available if batch_count=seq\")\n", - "\n", - " category2data = {} # Dict[str, dict]\n", - " for v in data:\n", - " k = v['utt']\n", - " category2data.setdefault(v.get(\"category\"), {})[k] = v\n", - "\n", - " batches_list = [] # List[List[List[Tuple[str, dict]]]]\n", - " for d in category2data.values():\n", - " if batch_sort_key == \"shuffle\":\n", - " batches = batchfy_shuffle(d, batch_size, min_batch_size,\n", - " num_batches, shortest_first)\n", - " batches_list.append(batches)\n", - " continue\n", - "\n", - " # sort it by input lengths (long to short)\n", - " sorted_data = sorted(\n", - " d.items(),\n", - " key=lambda data: int(data[1][batch_sort_key][batch_sort_axis][\"shape\"][0]),\n", - " reverse=not shortest_first, )\n", - " logger.info(\"# utts: \" + str(len(sorted_data)))\n", - " \n", - " if count == \"seq\":\n", - " batches = batchfy_by_seq(\n", - " sorted_data,\n", - " batch_size=batch_size,\n", - " max_length_in=max_length_in,\n", - " max_length_out=max_length_out,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " iaxis=iaxis,\n", - " okey=okey,\n", - " oaxis=oaxis, )\n", - " if count == \"bin\":\n", - " batches = batchfy_by_bin(\n", - " sorted_data,\n", - " batch_bins=batch_bins,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " okey=okey, )\n", - " if count == \"frame\":\n", - " batches = batchfy_by_frame(\n", - " sorted_data,\n", - " max_frames_in=batch_frames_in,\n", - " max_frames_out=batch_frames_out,\n", - " max_frames_inout=batch_frames_inout,\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=shortest_first,\n", - " ikey=ikey,\n", - " okey=okey, )\n", - " batches_list.append(batches)\n", - "\n", - " if len(batches_list) == 1:\n", - " batches = batches_list[0]\n", - " else:\n", - " # Concat list. This way is faster than \"sum(batch_list, [])\"\n", - " batches = list(itertools.chain(*batches_list))\n", - "\n", - " # for debugging\n", - " if num_batches > 0:\n", - " batches = batches[:num_batches]\n", - " logger.info(\"# minibatches: \" + str(len(batches)))\n", - "\n", - " # batch: List[List[Tuple[str, dict]]]\n", - " return batches\n" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "id": "acquired-hurricane", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO 2021/08/18 06:57:10 1445365138.py:284] use shuffled batch.\n", - "[INFO 2021/08/18 06:57:10 1445365138.py:286] # utts: 5542\n", - "[INFO 2021/08/18 06:57:10 1445365138.py:468] # minibatches: 555\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "555\n" - ] - } - ], - "source": [ - "batch_size=10\n", - "maxlen_in=300\n", - "maxlen_out=400\n", - "minibatches=0 # for debug\n", - "min_batch_size=2\n", - "use_sortagrad=True\n", - "batch_count='seq'\n", - "batch_bins=0\n", - "batch_frames_in=3000\n", - "batch_frames_out=0\n", - "batch_frames_inout=0\n", - " \n", - "dev_data = make_batchset(\n", - " dev_json,\n", - " batch_size,\n", - " maxlen_in,\n", - " maxlen_out,\n", - " minibatches, # for debug\n", - " min_batch_size=min_batch_size,\n", - " shortest_first=use_sortagrad,\n", - " batch_sort_key=\"shuffle\",\n", - " count=batch_count,\n", - " batch_bins=batch_bins,\n", - " batch_frames_in=batch_frames_in,\n", - " batch_frames_out=batch_frames_out,\n", - " batch_frames_inout=batch_frames_inout,\n", - " iaxis=0,\n", - " oaxis=0, )\n", - "print(len(dev_data))\n", - "# for i in range(len(dev_data)):\n", - "# print(len(dev_data[i]))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "id": "warming-malpractice", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: kaldiio in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (2.17.2)\n", - "Requirement already satisfied: numpy in ./DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numpy-1.21.2-py3.7-linux-x86_64.egg (from kaldiio) (1.21.2)\n", - "\u001b[33mWARNING: You are using pip version 20.3.3; however, version 21.2.4 is available.\n", - "You should consider upgrading via the '/workspace/zhanghui/DeepSpeech-2.x/tools/venv/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" - ] - } - ], - "source": [ - "!pip install kaldiio" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "equipped-subject", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 100, - "id": "superb-methodology", - "metadata": {}, - "outputs": [], - "source": [ - "from collections import OrderedDict\n", - "import kaldiio\n", - "\n", - "class LoadInputsAndTargets():\n", - " \"\"\"Create a mini-batch from a list of dicts\n", - "\n", - " >>> batch = [('utt1',\n", - " ... dict(input=[dict(feat='some.ark:123',\n", - " ... filetype='mat',\n", - " ... name='input1',\n", - " ... shape=[100, 80])],\n", - " ... output=[dict(tokenid='1 2 3 4',\n", - " ... name='target1',\n", - " ... shape=[4, 31])]]))\n", - " >>> l = LoadInputsAndTargets()\n", - " >>> feat, target = l(batch)\n", - "\n", - " :param: str mode: Specify the task mode, \"asr\" or \"tts\"\n", - " :param: str preprocess_conf: The path of a json file for pre-processing\n", - " :param: bool load_input: If False, not to load the input data\n", - " :param: bool load_output: If False, not to load the output data\n", - " :param: bool sort_in_input_length: Sort the mini-batch in descending order\n", - " of the input length\n", - " :param: bool use_speaker_embedding: Used for tts mode only\n", - " :param: bool use_second_target: Used for tts mode only\n", - " :param: dict preprocess_args: Set some optional arguments for preprocessing\n", - " :param: Optional[dict] preprocess_args: Used for tts mode only\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " mode=\"asr\",\n", - " preprocess_conf=None,\n", - " load_input=True,\n", - " load_output=True,\n", - " sort_in_input_length=True,\n", - " preprocess_args=None,\n", - " keep_all_data_on_mem=False, ):\n", - " self._loaders = {}\n", - "\n", - " if mode not in [\"asr\"]:\n", - " raise ValueError(\"Only asr are allowed: mode={}\".format(mode))\n", - "\n", - " if preprocess_conf is not None:\n", - " self.preprocessing = AugmentationPipeline(preprocess_conf)\n", - " logging.warning(\n", - " \"[Experimental feature] Some preprocessing will be done \"\n", - " \"for the mini-batch creation using {}\".format(\n", - " self.preprocessing))\n", - " else:\n", - " # If conf doesn't exist, this function don't touch anything.\n", - " self.preprocessing = None\n", - "\n", - " self.mode = mode\n", - " self.load_output = load_output\n", - " self.load_input = load_input\n", - " self.sort_in_input_length = sort_in_input_length\n", - " if preprocess_args is None:\n", - " self.preprocess_args = {}\n", - " else:\n", - " assert isinstance(preprocess_args, dict), type(preprocess_args)\n", - " self.preprocess_args = dict(preprocess_args)\n", - "\n", - " self.keep_all_data_on_mem = keep_all_data_on_mem\n", - "\n", - " def __call__(self, batch, return_uttid=False):\n", - " \"\"\"Function to load inputs and targets from list of dicts\n", - "\n", - " :param List[Tuple[str, dict]] batch: list of dict which is subset of\n", - " loaded data.json\n", - " :param bool return_uttid: return utterance ID information for visualization\n", - " :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]\n", - " :return: list of input feature sequences\n", - " [(T_1, D), (T_2, D), ..., (T_B, D)]\n", - " :rtype: list of float ndarray\n", - " :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]\n", - " :rtype: list of int ndarray\n", - "\n", - " \"\"\"\n", - " x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n", - " y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]\n", - " uttid_list = [] # List[str]\n", - "\n", - " for uttid, info in batch:\n", - " uttid_list.append(uttid)\n", - "\n", - " if self.load_input:\n", - " # Note(kamo): This for-loop is for multiple inputs\n", - " for idx, inp in enumerate(info[\"input\"]):\n", - " # {\"input\":\n", - " # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # \"name\": \"input1\", ...}], ...}\n", - " x = self._get_from_loader(\n", - " filepath=inp[\"feat\"],\n", - " filetype=inp.get(\"filetype\", \"mat\"))\n", - " x_feats_dict.setdefault(inp[\"name\"], []).append(x)\n", - "\n", - " if self.load_output:\n", - " for idx, inp in enumerate(info[\"output\"]):\n", - " if \"tokenid\" in inp:\n", - " # ======= Legacy format for output =======\n", - " # {\"output\": [{\"tokenid\": \"1 2 3 4\"}])\n", - " x = np.fromiter(\n", - " map(int, inp[\"tokenid\"].split()), dtype=np.int64)\n", - " else:\n", - " # ======= New format =======\n", - " # {\"input\":\n", - " # [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # \"name\": \"target1\", ...}], ...}\n", - " x = self._get_from_loader(\n", - " filepath=inp[\"feat\"],\n", - " filetype=inp.get(\"filetype\", \"mat\"))\n", - "\n", - " y_feats_dict.setdefault(inp[\"name\"], []).append(x)\n", - "\n", - " if self.mode == \"asr\":\n", - " return_batch, uttid_list = self._create_batch_asr(\n", - " x_feats_dict, y_feats_dict, uttid_list)\n", - " else:\n", - " raise NotImplementedError(self.mode)\n", - "\n", - " if self.preprocessing is not None:\n", - " # Apply pre-processing all input features\n", - " for x_name in return_batch.keys():\n", - " if x_name.startswith(\"input\"):\n", - " return_batch[x_name] = self.preprocessing(\n", - " return_batch[x_name], uttid_list,\n", - " **self.preprocess_args)\n", - "\n", - " if return_uttid:\n", - " return tuple(return_batch.values()), uttid_list\n", - "\n", - " # Doesn't return the names now.\n", - " return tuple(return_batch.values())\n", - "\n", - " def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):\n", - " \"\"\"Create a OrderedDict for the mini-batch\n", - "\n", - " :param OrderedDict x_feats_dict:\n", - " e.g. {\"input1\": [ndarray, ndarray, ...],\n", - " \"input2\": [ndarray, ndarray, ...]}\n", - " :param OrderedDict y_feats_dict:\n", - " e.g. {\"target1\": [ndarray, ndarray, ...],\n", - " \"target2\": [ndarray, ndarray, ...]}\n", - " :param: List[str] uttid_list:\n", - " Give uttid_list to sort in the same order as the mini-batch\n", - " :return: batch, uttid_list\n", - " :rtype: Tuple[OrderedDict, List[str]]\n", - " \"\"\"\n", - " # handle single-input and multi-input (paralell) asr mode\n", - " xs = list(x_feats_dict.values())\n", - "\n", - " if self.load_output:\n", - " ys = list(y_feats_dict.values())\n", - " assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))\n", - "\n", - " # get index of non-zero length samples\n", - " nonzero_idx = list(\n", - " filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))\n", - " for n in range(1, len(y_feats_dict)):\n", - " nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)\n", - " else:\n", - " # Note(kamo): Be careful not to make nonzero_idx to a generator\n", - " nonzero_idx = list(range(len(xs[0])))\n", - "\n", - " if self.sort_in_input_length:\n", - " # sort in input lengths based on the first input\n", - " nonzero_sorted_idx = sorted(\n", - " nonzero_idx, key=lambda i: -len(xs[0][i]))\n", - " else:\n", - " nonzero_sorted_idx = nonzero_idx\n", - "\n", - " if len(nonzero_sorted_idx) != len(xs[0]):\n", - " logging.warning(\n", - " \"Target sequences include empty tokenid (batch {} -> {}).\".\n", - " format(len(xs[0]), len(nonzero_sorted_idx)))\n", - "\n", - " # remove zero-length samples\n", - " xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]\n", - " uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]\n", - "\n", - " x_names = list(x_feats_dict.keys())\n", - " if self.load_output:\n", - " ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]\n", - " y_names = list(y_feats_dict.keys())\n", - "\n", - " # Keeping x_name and y_name, e.g. input1, for future extension\n", - " return_batch = OrderedDict([\n", - " * [(x_name, x) for x_name, x in zip(x_names, xs)],\n", - " * [(y_name, y) for y_name, y in zip(y_names, ys)],\n", - " ])\n", - " else:\n", - " return_batch = OrderedDict(\n", - " [(x_name, x) for x_name, x in zip(x_names, xs)])\n", - " return return_batch, uttid_list\n", - "\n", - " def _get_from_loader(self, filepath, filetype):\n", - " \"\"\"Return ndarray\n", - "\n", - " In order to make the fds to be opened only at the first referring,\n", - " the loader are stored in self._loaders\n", - "\n", - " >>> ndarray = loader.get_from_loader(\n", - " ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')\n", - "\n", - " :param: str filepath:\n", - " :param: str filetype:\n", - " :return:\n", - " :rtype: np.ndarray\n", - " \"\"\"\n", - " if filetype == \"hdf5\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"hdf5\",\n", - " # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = h5py.File(filepath, \"r\")\n", - " self._loaders[filepath] = loader\n", - " return loader[key][()]\n", - " elif filetype == \"sound.hdf5\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.h5:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"sound.hdf5\",\n", - " # -> filepath = \"some/path.h5\", key = \"F01_050C0101_PED_REAL\"\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = SoundHDF5File(filepath, \"r\", dtype=\"int16\")\n", - " self._loaders[filepath] = loader\n", - " array, rate = loader[key]\n", - " return array\n", - " elif filetype == \"sound\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.wav\",\n", - " # \"filetype\": \"sound\"},\n", - " # Assume PCM16\n", - " if not self.keep_all_data_on_mem:\n", - " array, _ = soundfile.read(filepath, dtype=\"int16\")\n", - " return array\n", - " if filepath not in self._loaders:\n", - " array, _ = soundfile.read(filepath, dtype=\"int16\")\n", - " self._loaders[filepath] = array\n", - " return self._loaders[filepath]\n", - " elif filetype == \"npz\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.npz:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"npz\",\n", - " filepath, key = filepath.split(\":\", 1)\n", - "\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = np.load(filepath)\n", - " self._loaders[filepath] = loader\n", - " return loader[key]\n", - " elif filetype == \"npy\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.npy\",\n", - " # \"filetype\": \"npy\"},\n", - " if not self.keep_all_data_on_mem:\n", - " return np.load(filepath)\n", - " if filepath not in self._loaders:\n", - " self._loaders[filepath] = np.load(filepath)\n", - " return self._loaders[filepath]\n", - " elif filetype in [\"mat\", \"vec\"]:\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.ark:123\",\n", - " # \"filetype\": \"mat\"}]},\n", - " # In this case, \"123\" indicates the starting points of the matrix\n", - " # load_mat can load both matrix and vector\n", - " if not self.keep_all_data_on_mem:\n", - " return kaldiio.load_mat(filepath)\n", - " if filepath not in self._loaders:\n", - " self._loaders[filepath] = kaldiio.load_mat(filepath)\n", - " return self._loaders[filepath]\n", - " elif filetype == \"scp\":\n", - " # e.g.\n", - " # {\"input\": [{\"feat\": \"some/path.scp:F01_050C0101_PED_REAL\",\n", - " # \"filetype\": \"scp\",\n", - " filepath, key = filepath.split(\":\", 1)\n", - " loader = self._loaders.get(filepath)\n", - " if loader is None:\n", - " # To avoid disk access, create loader only for the first time\n", - " loader = kaldiio.load_scp(filepath)\n", - " self._loaders[filepath] = loader\n", - " return loader[key]\n", - " else:\n", - " raise NotImplementedError(\n", - " \"Not supported: loader_type={}\".format(filetype))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "id": "monthly-muscle", - "metadata": {}, - "outputs": [], - "source": [ - "preprocess_conf=None\n", - "train_mode=True\n", - "load = LoadInputsAndTargets(\n", - " mode=\"asr\",\n", - " load_output=True,\n", - " preprocess_conf=preprocess_conf,\n", - " preprocess_args={\"train\":\n", - " train_mode}, # Switch the mode of preprocessing\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "id": "periodic-senegal", - "metadata": {}, - "outputs": [], - "source": [ - "res = load(dev_data[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "id": "502d3f4d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "2\n", - "10\n", - "10\n", - "(1174, 83) float32\n", - "(29,) int64\n" - ] - } - ], - "source": [ - "print(type(res))\n", - "print(len(res))\n", - "print(len(res[0]))\n", - "print(len(res[1]))\n", - "print(res[0][0].shape, res[0][0].dtype)\n", - "print(res[1][0].shape, res[1][0].dtype)\n", - "# Tuple[Tuple[np.ndarry], Tuple[np.ndarry]]\n", - "# 2[10, 10]\n", - "# feats, labels" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "id": "humanitarian-container", - "metadata": {}, - "outputs": [], - "source": [ - "(inputs, outputs), utts = load(dev_data[0], return_uttid=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "id": "heard-prize", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027'] 10\n", - "10\n" - ] - } - ], - "source": [ - "print(utts, len(utts))\n", - "print(len(inputs))" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "id": "convinced-animation", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from deepspeech.io.utility import pad_list\n", - "class CustomConverter():\n", - " \"\"\"Custom batch converter.\n", - "\n", - " Args:\n", - " subsampling_factor (int): The subsampling factor.\n", - " dtype (paddle.dtype): Data type to convert.\n", - "\n", - " \"\"\"\n", - "\n", - " def __init__(self, subsampling_factor=1, dtype=np.float32):\n", - " \"\"\"Construct a CustomConverter object.\"\"\"\n", - " self.subsampling_factor = subsampling_factor\n", - " self.ignore_id = -1\n", - " self.dtype = dtype\n", - "\n", - " def __call__(self, batch):\n", - " \"\"\"Transform a batch and send it to a device.\n", - "\n", - " Args:\n", - " batch (list): The batch to transform.\n", - "\n", - " Returns:\n", - " tuple(paddle.Tensor, paddle.Tensor, paddle.Tensor)\n", - "\n", - " \"\"\"\n", - " # batch should be located in list\n", - " assert len(batch) == 1\n", - " (xs, ys), utts = batch[0]\n", - "\n", - " # perform subsampling\n", - " if self.subsampling_factor > 1:\n", - " xs = [x[::self.subsampling_factor, :] for x in xs]\n", - "\n", - " # get batch of lengths of input sequences\n", - " ilens = np.array([x.shape[0] for x in xs])\n", - "\n", - " # perform padding and convert to tensor\n", - " # currently only support real number\n", - " if xs[0].dtype.kind == \"c\":\n", - " xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)\n", - " xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)\n", - " # Note(kamo):\n", - " # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.\n", - " # Don't create ComplexTensor and give it E2E here\n", - " # because torch.nn.DataParellel can't handle it.\n", - " xs_pad = {\"real\": xs_pad_real, \"imag\": xs_pad_imag}\n", - " else:\n", - " xs_pad = pad_list(xs, 0).astype(self.dtype)\n", - "\n", - " # NOTE: this is for multi-output (e.g., speech translation)\n", - " ys_pad = pad_list(\n", - " [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],\n", - " self.ignore_id)\n", - "\n", - " olens = np.array([y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])\n", - " return utts, xs_pad, ilens, ys_pad, olens" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "id": "0b92ade5", - "metadata": {}, - "outputs": [], - "source": [ - "convert = CustomConverter()" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "id": "8dbd847c", - "metadata": {}, - "outputs": [], - "source": [ - "utts, xs, ilen, ys, olen = convert([load(dev_data[0], return_uttid=True)])" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "id": "31c085f4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['4572-112383-0005', '6313-66125-0015', '251-137823-0022', '2277-149896-0030', '652-130726-0032', '5895-34615-0013', '1462-170138-0002', '777-126732-0008', '3660-172182-0021', '2277-149896-0027']\n", - "(10, 1174, 83)\n", - "(10,)\n", - "[1174 821 716 628 597 473 463 441 419 358]\n", - "(10, 32)\n", - "[[4502 2404 4223 3204 4502 587 1018 3861 2932 713 2458 2916 253 4508\n", - " 627 1395 713 4504 957 2761 209 2967 3173 3918 2598 4100 3 2816\n", - " 4990 -1 -1 -1]\n", - " [1005 451 210 278 3411 206 482 2307 573 4502 3848 4577 4273 2388\n", - " 4444 89 4919 278 1264 4501 2371 3 139 113 2603 4962 3158 3325\n", - " 4577 814 4587 1422]\n", - " [2345 4144 2291 200 713 2345 532 999 2458 3076 545 2458 4832 3038\n", - " 4499 482 2812 1260 3080 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2345 832 4577 4920 4501 2345 2298 1236 381 288 389 101 2495 4172\n", - " 4843 3233 3245 4501 2345 2298 3987 4502 3023 3353 2345 1361 1635 2603\n", - " 4723 2371 -1 -1]\n", - " [4502 4207 432 3204 4502 2396 125 935 433 2598 483 18 327 2\n", - " 389 627 4512 2340 713 482 1981 4525 4031 269 2030 1340 101 2495\n", - " 4013 4844 -1 -1]\n", - " [4502 4892 3204 1892 3780 389 482 2774 3013 89 192 2495 4502 3475\n", - " 389 66 370 343 404 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2458 2314 4577 2340 2863 1254 303 269 2 389 932 2079 4577 299\n", - " 195 3233 4508 2 89 814 3144 1091 3204 3250 2193 3414 -1 -1\n", - " -1 -1 -1 -1]\n", - " [2391 1785 443 78 39 4962 2340 829 599 4593 278 4681 202 407\n", - " 269 194 182 4577 482 4308 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [ 627 4873 2175 363 202 404 1018 4577 4502 3412 4875 2286 107 122\n", - " 4832 2345 3896 89 2368 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]\n", - " [ 481 174 474 599 1881 3252 2842 742 4502 2545 107 88 3204 4525\n", - " 4517 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1\n", - " -1 -1 -1 -1]]\n", - "[29 32 19 30 30 19 26 20 19 15]\n", - "float32\n", - "int64\n", - "int64\n", - "int64\n" - ] - } - ], - "source": [ - "print(utts)\n", - "print(xs.shape)\n", - "print(ilen.shape)\n", - "print(ilen)\n", - "print(ys.shape)\n", - "print(ys)\n", - "print(olen)\n", - "print(xs.dtype)\n", - "print(ilen.dtype)\n", - "print(ys.dtype)\n", - "print(olen.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "id": "72e9ba60", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 230, - "id": "64593e5f", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "from paddle.io import DataLoader\n", - "\n", - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.io.batchfy import make_batchset\n", - "from deepspeech.io.converter import CustomConverter\n", - "from deepspeech.io.dataset import TransformDataset\n", - "from deepspeech.io.reader import LoadInputsAndTargets\n", - "from deepspeech.utils.log import Log\n", - "\n", - "\n", - "logger = Log(__name__).getlog()\n", - "\n", - "\n", - "class BatchDataLoader():\n", - " def __init__(self,\n", - " json_file: str,\n", - " train_mode: bool,\n", - " sortagrad: bool=False,\n", - " batch_size: int=0,\n", - " maxlen_in: float=float('inf'),\n", - " maxlen_out: float=float('inf'),\n", - " minibatches: int=0,\n", - " mini_batch_size: int=1,\n", - " batch_count: str='auto',\n", - " batch_bins: int=0,\n", - " batch_frames_in: int=0,\n", - " batch_frames_out: int=0,\n", - " batch_frames_inout: int=0,\n", - " preprocess_conf=None,\n", - " n_iter_processes: int=1,\n", - " subsampling_factor: int=1,\n", - " num_encs: int=1):\n", - " self.json_file = json_file\n", - " self.train_mode = train_mode\n", - " self.use_sortagrad = sortagrad == -1 or sortagrad > 0\n", - " self.batch_size = batch_size\n", - " self.maxlen_in = maxlen_in\n", - " self.maxlen_out = maxlen_out\n", - " self.batch_count = batch_count\n", - " self.batch_bins = batch_bins\n", - " self.batch_frames_in = batch_frames_in\n", - " self.batch_frames_out = batch_frames_out\n", - " self.batch_frames_inout = batch_frames_inout\n", - " self.subsampling_factor = subsampling_factor\n", - " self.num_encs = num_encs\n", - " self.preprocess_conf = preprocess_conf\n", - " self.n_iter_processes = n_iter_processes\n", - "\n", - " \n", - " # read json data\n", - " self.data_json = read_manifest(json_file)\n", - "\n", - " # make minibatch list (variable length)\n", - " self.minibaches = make_batchset(\n", - " self.data_json,\n", - " batch_size,\n", - " maxlen_in,\n", - " maxlen_out,\n", - " minibatches, # for debug\n", - " min_batch_size=mini_batch_size,\n", - " shortest_first=self.use_sortagrad,\n", - " count=batch_count,\n", - " batch_bins=batch_bins,\n", - " batch_frames_in=batch_frames_in,\n", - " batch_frames_out=batch_frames_out,\n", - " batch_frames_inout=batch_frames_inout,\n", - " iaxis=0,\n", - " oaxis=0, )\n", - "\n", - " # data reader\n", - " self.reader = LoadInputsAndTargets(\n", - " mode=\"asr\",\n", - " load_output=True,\n", - " preprocess_conf=preprocess_conf,\n", - " preprocess_args={\"train\":\n", - " train_mode}, # Switch the mode of preprocessing\n", - " )\n", - "\n", - " # Setup a converter\n", - " if num_encs == 1:\n", - " self.converter = CustomConverter(\n", - " subsampling_factor=subsampling_factor, dtype=np.float32)\n", - " else:\n", - " assert NotImplementedError(\"not impl CustomConverterMulEnc.\")\n", - "\n", - " # hack to make batchsize argument as 1\n", - " # actual bathsize is included in a list\n", - " # default collate function converts numpy array to pytorch tensor\n", - " # we used an empty collate function instead which returns list\n", - " self.dataset = TransformDataset(self.minibaches, \n", - " lambda data: self.converter([self.reader(data, return_uttid=True)]))\n", - " self.dataloader = DataLoader(\n", - " dataset=self.dataset,\n", - " batch_size=1,\n", - " shuffle=not use_sortagrad if train_mode else False,\n", - " collate_fn=lambda x: x[0],\n", - " num_workers=n_iter_processes, )\n", - "\n", - " def __repr__(self):\n", - " echo = f\"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> \"\n", - " echo += f\"train_mode: {self.train_mode}, \"\n", - " echo += f\"sortagrad: {self.use_sortagrad}, \"\n", - " echo += f\"batch_size: {self.batch_size}, \"\n", - " echo += f\"maxlen_in: {self.maxlen_in}, \"\n", - " echo += f\"maxlen_out: {self.maxlen_out}, \"\n", - " echo += f\"batch_count: {self.batch_count}, \"\n", - " echo += f\"batch_bins: {self.batch_bins}, \"\n", - " echo += f\"batch_frames_in: {self.batch_frames_in}, \"\n", - " echo += f\"batch_frames_out: {self.batch_frames_out}, \"\n", - " echo += f\"batch_frames_inout: {self.batch_frames_inout}, \"\n", - " echo += f\"subsampling_factor: {self.subsampling_factor}, \"\n", - " echo += f\"num_encs: {self.num_encs}, \"\n", - " echo += f\"num_workers: {self.n_iter_processes}, \"\n", - " echo += f\"file: {self.json_file}\"\n", - " return echo\n", - " \n", - " def __len__(self):\n", - " return len(self.dataloader)\n", - " \n", - " def __iter__(self):\n", - " return self.dataloader.__iter__()\n", - " \n", - " def __call__(self):\n", - " return self.__iter__()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 231, - "id": "fcea3fd0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO 2021/08/18 07:42:23 batchfy.py:399] count is auto detected as seq\n", - "[INFO 2021/08/18 07:42:23 batchfy.py:423] # utts: 5542\n", - "[INFO 2021/08/18 07:42:23 batchfy.py:466] # minibatches: 278\n" - ] - } - ], - "source": [ - "train = BatchDataLoader(dev_data, True, batch_size=20)" - ] - }, - { - "cell_type": "code", - "execution_count": 232, - "id": "e2a2c9a8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "278\n", - "['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'auto_collate_batch', 'batch_sampler', 'batch_size', 'collate_fn', 'dataset', 'dataset_kind', 'feed_list', 'from_dataset', 'from_generator', 'num_workers', 'pin_memory', 'places', 'return_list', 'timeout', 'use_buffer_reader', 'use_shared_memory', 'worker_init_fn']\n", - "<__main__.BatchDataLoader object at 0x7fdddba35470> train_mode: True, sortagrad: False, batch_size: 20, maxlen_in: inf, maxlen_out: inf, batch_count: auto, batch_bins: 0, batch_frames_in: 0, batch_frames_out: 0, batch_frames_inout: 0, subsampling_factor: 1, num_encs: 1, num_workers: 1, file: /workspace/zhanghui/DeepSpeech-2.x/examples/librispeech/s2/data/manifest.dev\n", - "278\n" - ] - } - ], - "source": [ - "print(len(train.dataloader))\n", - "print(dir(train.dataloader))\n", - "print(train)\n", - "print(len(train))" - ] - }, - { - "cell_type": "code", - "execution_count": 220, - "id": "a5ba7d6e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['7601-101619-0003', '1255-138279-0000', '1272-128104-0004', '6123-59150-0027', '2078-142845-0025', '7850-73752-0018', '4570-24733-0004', '2506-169427-0002', '7601-101619-0004', '3170-137482-0000', '6267-53049-0019', '4570-14911-0009', '174-168635-0018', '7601-291468-0004', '3576-138058-0022', '1919-142785-0007', '6467-62797-0007', '4153-61735-0005', '1686-142278-0003', '2506-169427-0000']\n", - "Tensor(shape=[20, 2961, 83], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[-1.99415934, -1.80315673, -1.88801885, ..., 0.86933994, -0.59853148, 0.02596200],\n", - " [-1.95346808, -1.84891188, -2.17492867, ..., 0.83640492, -0.59853148, -0.11333394],\n", - " [-2.27899861, -2.21495342, -2.58480024, ..., 0.91874266, -0.59853148, -0.31453922],\n", - " ...,\n", - " [-2.64522028, -2.35221887, -2.91269732, ..., 1.48994756, -0.16100442, 0.36646330],\n", - " [-2.40107250, -2.21495342, -2.37986445, ..., 1.44072104, -0.13220564, 0.12656468],\n", - " [-2.15692472, -1.89466715, -2.25690317, ..., 1.31273174, -0.09620714, -0.15202725]],\n", - "\n", - " [[-0.28859532, -0.29033494, -0.86576819, ..., 1.37753224, -0.30570769, 0.25806731],\n", - " [-0.20149794, -0.17814466, -0.59891301, ..., 1.35188794, -0.30570769, -0.02964944],\n", - " [-0.34947991, -0.33597648, -0.96877253, ..., 1.38394332, -0.30570769, -0.38376236],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.44914246, -0.33902276, -0.78237975, ..., 1.38218808, 0.29214793, -0.16815147],\n", - " [-0.55490732, -0.41596055, -0.84425378, ..., 1.34530187, 0.25002354, -0.04004869],\n", - " [-0.83694696, -0.62112784, -1.07112527, ..., 1.19160914, 0.20789915, 0.37984371],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[-1.24343657, -0.94188881, -1.41092563, ..., 0.96716309, 0.60345763, 0.15360183],\n", - " [-1.19466043, -0.80585432, -0.49723154, ..., 1.06735480, 0.60345763, 0.14511746],\n", - " [-0.94079566, -0.59330046, -0.40948665, ..., 0.82244170, 0.55614340, 0.28086722],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.21757117, 0.11361472, -0.33262897, ..., 0.76338506, -0.10711290, -0.57754958],\n", - " [-1.00205481, -0.61152041, -0.47124696, ..., 1.11897349, -0.10711290, 0.24931324],\n", - " [-1.03929281, -1.20336759, -1.16433656, ..., 0.88888687, -0.10711290, -0.04115745],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-1.25289667, -1.05046368, -0.82881606, ..., 1.23991334, 0.61702502, 0.05275881],\n", - " [-1.19659519, -0.78677225, -0.80407262, ..., 1.27644968, 0.61702502, -0.35079369],\n", - " [-1.49687004, -1.01750231, -0.82881606, ..., 1.29106426, 0.65006059, 0.17958963],\n", - " ...,\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [ 0. , 0. , 0. , ..., 0. , 0. , 0. ]]])\n", - "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [2961, 2948, 2938, 2907, 2904, 2838, 2832, 2819, 2815, 2797, 2775, 2710, 2709, 2696, 2688, 2661, 2616, 2595, 2589, 2576])\n", - "Tensor(shape=[20, 133], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[3098, 1595, 389, ..., -1 , -1 , -1 ],\n", - " [2603, 4832, 482, ..., -1 , -1 , -1 ],\n", - " [2796, 303, 269, ..., -1 , -1 , -1 ],\n", - " ...,\n", - " [3218, 3673, 206, ..., -1 , -1 , -1 ],\n", - " [2371, 4832, 4031, ..., -1 , -1 , -1 ],\n", - " [2570, 2433, 4285, ..., -1 , -1 , -1 ]])\n", - "Tensor(shape=[20], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [80 , 83 , 102, 133, 82 , 102, 71 , 91 , 68 , 81 , 86 , 67 , 71 , 95 , 65 , 88 , 97 , 98 , 89 , 72 ])\n" - ] - } - ], - "source": [ - "for batch in train:\n", - " utts, xs, ilens, ys, olens = batch\n", - " print(utts)\n", - " print(xs)\n", - " print(ilens)\n", - " print(ys)\n", - " print(olens)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3c974a1e", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/hack_api_test.ipynb b/.notebook/hack_api_test.ipynb deleted file mode 100644 index f653084e6..000000000 --- a/.notebook/hack_api_test.ipynb +++ /dev/null @@ -1,290 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "breeding-haven", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "appropriate-theta", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "LICENSE deepspeech examples\t\t requirements.txt tools\r\n", - "README.md docs\t libsndfile-1.0.28\t setup.sh\t utils\r\n", - "README_cn.md env.sh\t libsndfile-1.0.28.tar.gz tests\r\n" - ] - } - ], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "entire-bloom", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "WARNING:root:override cat of paddle.Tensor if exists or register, remove this when fixed!\n", - "WARNING:root:register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user repeat to paddle.Tensor, remove this when fixed!\n", - "WARNING:root:register user glu to paddle.nn.functional, remove this when fixed!\n", - "WARNING:root:register user GLU to paddle.nn, remove this when fixed!\n", - "WARNING:root:register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "WARNING:root:override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n" - ] - } - ], - "source": [ - "from deepspeech.modules import loss" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "governmental-aircraft", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "import paddle" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "proprietary-disaster", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " paddle.VarBase>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.repeat" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "first-diagram", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.size" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "intelligent-david", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "paddle.Tensor.cat" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "bronze-tenant", - "metadata": {}, - "outputs": [], - "source": [ - "a = paddle.to_tensor([12,32, 10, 12, 123,32 ,4])" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "balanced-bearing", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "7" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "extreme-republic", - "metadata": {}, - "outputs": [], - "source": [ - "def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:\n", - " nargs = len(args)\n", - " assert (nargs <= 1)\n", - " s = paddle.shape(xs)\n", - " if nargs == 1:\n", - " return s[args[0]]\n", - " else:\n", - " return s\n", - "\n", - "# logger.warn(\n", - "# \"override size of paddle.Tensor if exists or register, remove this when fixed!\"\n", - "# )\n", - "paddle.Tensor.size = size" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "gross-addiction", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [7])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size(0)\n", - "a.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "adverse-dining", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [7])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a.size()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "popular-potato", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb deleted file mode 100644 index 20882c1ae..000000000 --- a/.notebook/jit_infer.ipynb +++ /dev/null @@ -1,672 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/home/ssd5/zhanghui/DeepSpeech2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-03-26 02:55:23,873 - WARNING - register user softmax to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,875 - WARNING - register user sigmoid to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,875 - WARNING - register user relu to paddle, remove this when fixed!\n", - "2021-03-26 02:55:23,876 - WARNING - override cat of paddle if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,876 - WARNING - override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,877 - WARNING - override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "2021-03-26 02:55:23,877 - WARNING - override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "2021-03-26 02:55:23,878 - WARNING - register user view to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,878 - WARNING - register user view_as to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,879 - WARNING - register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,880 - WARNING - register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,880 - WARNING - register user fill_ to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,881 - WARNING - register user repeat to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,881 - WARNING - register user softmax to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,882 - WARNING - register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,882 - WARNING - register user relu to paddle.Tensor, remove this when fixed!\n", - "2021-03-26 02:55:23,883 - WARNING - register user glu to paddle.nn.functional, remove this when fixed!\n", - "2021-03-26 02:55:23,883 - WARNING - override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "2021-03-26 02:55:23,884 - WARNING - register user GLU to paddle.nn, remove this when fixed!\n", - "2021-03-26 02:55:23,884 - WARNING - register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n" - ] - } - ], - "source": [ - "import os\n", - "import time\n", - "import argparse\n", - "import functools\n", - "import paddle\n", - "import numpy as np\n", - "\n", - "from deepspeech.utils.socket_server import warm_up_test\n", - "from deepspeech.utils.socket_server import AsrTCPServer\n", - "from deepspeech.utils.socket_server import AsrRequestHandler\n", - "\n", - "from deepspeech.training.cli import default_argument_parser\n", - "from deepspeech.exps.deepspeech2.config import get_cfg_defaults\n", - "\n", - "from deepspeech.frontend.utility import read_manifest\n", - "from deepspeech.utils.utility import add_arguments, print_arguments\n", - "\n", - "from deepspeech.models.ds2 import DeepSpeech2Model\n", - "from deepspeech.models.ds2 import DeepSpeech2InferModel\n", - "from deepspeech.io.dataset import ManifestDataset\n", - "\n", - "\n", - "\n", - "from deepspeech.frontend.utility import read_manifest" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0.0\n", - "e7f28d6c0db54eb9c9a810612300b526687e56a6\n", - "OFF\n", - "OFF\n", - "commit: e7f28d6c0db54eb9c9a810612300b526687e56a6\n", - "None\n", - "0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ssd5/zhanghui/DeepSpeech2.x/tools/venv-dev/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "data": { - "text/plain": [ - "['__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " 'commit',\n", - " 'full_version',\n", - " 'istaged',\n", - " 'major',\n", - " 'minor',\n", - " 'mkl',\n", - " 'patch',\n", - " 'rc',\n", - " 'show',\n", - " 'with_mkl']" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "print(paddle.__version__)\n", - "print(paddle.version.commit)\n", - "print(paddle.version.with_mkl)\n", - "print(paddle.version.mkl())\n", - "print(paddle.version.show())\n", - "print(paddle.version.patch)\n", - "dir(paddle.version)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "data:\n", - " augmentation_config: conf/augmentation.config\n", - " batch_size: 64\n", - " dev_manifest: data/manifest.dev\n", - " keep_transcription_text: False\n", - " max_duration: 27.0\n", - " max_freq: None\n", - " mean_std_filepath: examples/aishell/data/mean_std.npz\n", - " min_duration: 0.0\n", - " n_fft: None\n", - " num_workers: 0\n", - " random_seed: 0\n", - " shuffle_method: batch_shuffle\n", - " sortagrad: True\n", - " specgram_type: linear\n", - " stride_ms: 10.0\n", - " target_dB: -20\n", - " target_sample_rate: 16000\n", - " test_manifest: examples/aishell/data/manifest.test\n", - " train_manifest: data/manifest.train\n", - " use_dB_normalization: True\n", - " vocab_filepath: examples/aishell/data/vocab.txt\n", - " window_ms: 20.0\n", - "decoding:\n", - " alpha: 2.6\n", - " batch_size: 128\n", - " beam_size: 300\n", - " beta: 5.0\n", - " cutoff_prob: 0.99\n", - " cutoff_top_n: 40\n", - " decoding_method: ctc_beam_search\n", - " error_rate_type: cer\n", - " lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm\n", - " num_proc_bsearch: 10\n", - "model:\n", - " num_conv_layers: 2\n", - " num_rnn_layers: 3\n", - " rnn_layer_size: 1024\n", - " share_rnn_weights: False\n", - " use_gru: True\n", - "training:\n", - " global_grad_clip: 5.0\n", - " lr: 0.0005\n", - " lr_decay: 0.83\n", - " n_epoch: 30\n", - " weight_decay: 1e-06\n", - "----------- Configuration Arguments -----------\n", - "checkpoint_path: examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725\n", - "config: examples/aishell/conf/deepspeech2.yaml\n", - "device: gpu\n", - "dump_config: None\n", - "export_path: None\n", - "host_ip: localhost\n", - "host_port: 8086\n", - "model_dir: None\n", - "model_file: examples/aishell/jit.model.pdmodel\n", - "nprocs: 1\n", - "opts: ['data.test_manifest', 'examples/aishell/data/manifest.test', 'data.mean_std_filepath', 'examples/aishell/data/mean_std.npz', 'data.vocab_filepath', 'examples/aishell/data/vocab.txt']\n", - "output: None\n", - "params_file: examples/aishell/jit.model.pdiparams\n", - "speech_save_dir: demo_cache\n", - "use_gpu: False\n", - "warmup_manifest: examples/aishell/data/manifest.test\n", - "------------------------------------------------\n" - ] - } - ], - "source": [ - "parser = default_argument_parser()\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "add_arg('host_ip', str,\n", - " 'localhost',\n", - " \"Server's IP address.\")\n", - "add_arg('host_port', int, 8086, \"Server's IP port.\")\n", - "add_arg('speech_save_dir', str,\n", - " 'demo_cache',\n", - " \"Directory to save demo audios.\")\n", - "add_arg('warmup_manifest', \n", - " str, \n", - " \"examples/aishell/data/manifest.test\", \n", - " \"Filepath of manifest to warm up.\")\n", - "add_arg(\n", - " \"--model_file\",\n", - " type=str,\n", - " default=\"examples/aishell/jit.model.pdmodel\",\n", - " help=\"Model filename, Specify this when your model is a combined model.\"\n", - ")\n", - "add_arg(\n", - " \"--params_file\",\n", - " type=str,\n", - " default=\"examples/aishell/jit.model.pdiparams\",\n", - " help=\n", - " \"Parameter filename, Specify this when your model is a combined model.\"\n", - ")\n", - "add_arg(\n", - " \"--model_dir\",\n", - " type=str,\n", - " default=None,\n", - " help=\n", - " \"Model dir, If you load a non-combined model, specify the directory of the model.\"\n", - ")\n", - "add_arg(\"--use_gpu\",type=bool,default=False, help=\"Whether use gpu.\")\n", - "\n", - "\n", - "args = parser.parse_args(\n", - " \"--checkpoint_path examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725 --config examples/aishell/conf/deepspeech2.yaml --opts data.test_manifest examples/aishell/data/manifest.test data.mean_std_filepath examples/aishell/data/mean_std.npz data.vocab_filepath examples/aishell/data/vocab.txt\".split()\n", - ")\n", - "\n", - "\n", - "config = get_cfg_defaults()\n", - "if args.config:\n", - " config.merge_from_file(args.config)\n", - "if args.opts:\n", - " config.merge_from_list(args.opts)\n", - "config.freeze()\n", - "print(config)\n", - "\n", - "args.warmup_manifest = config.data.test_manifest\n", - "\n", - "print_arguments(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = ManifestDataset(\n", - " config.data.test_manifest,\n", - " config.data.unit_type,\n", - " config.data.vocab_filepath,\n", - " config.data.mean_std_filepath,\n", - " augmentation_config=\"{}\",\n", - " max_duration=config.data.max_duration,\n", - " min_duration=config.data.min_duration,\n", - " stride_ms=config.data.stride_ms,\n", - " window_ms=config.data.window_ms,\n", - " n_fft=config.data.n_fft,\n", - " max_freq=config.data.max_freq,\n", - " target_sample_rate=config.data.target_sample_rate,\n", - " specgram_type=config.data.specgram_type,\n", - " feat_dim=config.data.feat_dim,\n", - " delta_delta=config.data.delat_delta,\n", - " use_dB_normalization=config.data.use_dB_normalization,\n", - " target_dB=config.data.target_dB,\n", - " random_seed=config.data.random_seed,\n", - " keep_transcription_text=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-03-26 02:55:57,930 - INFO - [checkpoint] Rank 0: loaded model from examples/aishell/ckpt-loss2e-3-0.83-5/checkpoints/step-11725.pdparams\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "layer summary:\n", - "encoder.conv.conv_in.conv.weight|[32, 1, 41, 11]|14432\n", - "encoder.conv.conv_in.bn.weight|[32]|32\n", - "encoder.conv.conv_in.bn.bias|[32]|32\n", - "encoder.conv.conv_in.bn._mean|[32]|32\n", - "encoder.conv.conv_in.bn._variance|[32]|32\n", - "encoder.conv.conv_stack.0.conv.weight|[32, 32, 21, 11]|236544\n", - "encoder.conv.conv_stack.0.bn.weight|[32]|32\n", - "encoder.conv.conv_stack.0.bn.bias|[32]|32\n", - "encoder.conv.conv_stack.0.bn._mean|[32]|32\n", - "encoder.conv.conv_stack.0.bn._variance|[32]|32\n", - "encoder.rnn.rnn_stacks.0.fw_fc.weight|[1312, 3072]|4030464\n", - "encoder.rnn.rnn_stacks.0.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_fc.weight|[1312, 3072]|4030464\n", - "encoder.rnn.rnn_stacks.0.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.0.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.0.bw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.1.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.1.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.1.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.1.bw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.2.fw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_fc.weight|[2048, 3072]|6291456\n", - "encoder.rnn.rnn_stacks.2.bw_bn.weight|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn.bias|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn._mean|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_bn._variance|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.fw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.bw_cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.fw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.fw_rnn.cell.bias_hh|[3072]|3072\n", - "encoder.rnn.rnn_stacks.2.bw_rnn.cell.weight_hh|[3072, 1024]|3145728\n", - "encoder.rnn.rnn_stacks.2.bw_rnn.cell.bias_hh|[3072]|3072\n", - "decoder.ctc_lo.weight|[2048, 4300]|8806400\n", - "decoder.ctc_lo.bias|[4300]|4300\n", - "layer has 66 parameters, 80148012 elements.\n" - ] - } - ], - "source": [ - "model = DeepSpeech2InferModel.from_pretrained(dataset, config,\n", - " args.checkpoint_path)\n", - "model.eval()" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "examples/aishell/jit.model.pdmodel\n", - "examples/aishell/jit.model.pdiparams\n", - "0\n", - "False\n" - ] - } - ], - "source": [ - "\n", - "from paddle.inference import Config\n", - "from paddle.inference import PrecisionType\n", - "from paddle.inference import create_predictor\n", - "\n", - "args.use_gpu=False\n", - "paddle.set_device('cpu')\n", - "\n", - "def init_predictor(args):\n", - " if args.model_dir is not None:\n", - " config = Config(args.model_dir)\n", - " else:\n", - " config = Config(args.model_file, args.params_file)\n", - "\n", - " if args.use_gpu:\n", - " config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)\n", - "# config.enable_tensorrt_engine(precision_mode=PrecisionType.Float32,\n", - "# use_calib_mode=True) # 开启TensorRT预测,精度为fp32,开启int8离线量化\n", - " else:\n", - " # If not specific mkldnn, you can set the blas thread.\n", - " # The thread num should not be greater than the number of cores in the CPU.\n", - " config.set_cpu_math_library_num_threads(1)\n", - " config.enable_mkldnn()\n", - " \n", - " config.enable_memory_optim()\n", - " config.switch_ir_optim(True)\n", - " \n", - " print(config.model_dir())\n", - " print(config.prog_file())\n", - " print(config.params_file())\n", - " print(config.gpu_device_id())\n", - " print(args.use_gpu)\n", - " predictor = create_predictor(config)\n", - " return predictor\n", - "\n", - "def run(predictor, audio, audio_len):\n", - " # copy img data to input tensor\n", - " input_names = predictor.get_input_names()\n", - " for i, name in enumerate(input_names):\n", - " print(\"input:\", i, name)\n", - " \n", - " audio_tensor = predictor.get_input_handle('audio')\n", - " audio_tensor.reshape(audio.shape)\n", - " audio_tensor.copy_from_cpu(audio.copy())\n", - " \n", - " audiolen_tensor = predictor.get_input_handle('audio_len')\n", - " audiolen_tensor.reshape(audio_len.shape)\n", - " audiolen_tensor.copy_from_cpu(audio_len.copy())\n", - "\n", - " output_names = predictor.get_output_names()\n", - " for i, name in enumerate(output_names):\n", - " print(\"output:\", i, name)\n", - "\n", - " # do the inference\n", - " predictor.run()\n", - "\n", - " results = []\n", - " # get out data from output tensor\n", - " output_names = predictor.get_output_names()\n", - " for i, name in enumerate(output_names):\n", - " output_tensor = predictor.get_output_handle(name)\n", - " output_data = output_tensor.copy_to_cpu()\n", - " results.append(output_data)\n", - "\n", - " return results\n", - "\n", - "\n", - "predictor = init_predictor(args)\n", - "\n", - "def file_to_transcript(filename):\n", - " print(filename)\n", - " feature = dataset.process_utterance(filename, \"\")\n", - " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", - " audio_len = feature[0].shape[1]\n", - " audio_len = np.array([audio_len]).astype('int64') # [1]\n", - " \n", - " \n", - " i_probs = run(predictor, audio, audio_len)\n", - " print('jit:', i_probs[0], type(i_probs[0]))\n", - " \n", - " audio = paddle.to_tensor(audio)\n", - " audio_len = paddle.to_tensor(audio_len)\n", - " print(audio.shape)\n", - " print(audio_len.shape)\n", - " \n", - " #eouts, eouts_len = model.encoder(audio, audio_len)\n", - " #probs = model.decoder.softmax(eouts)\n", - " probs = model.forward(audio, audio_len)\n", - " print('paddle:', probs.numpy())\n", - " \n", - " flag = np.allclose(i_probs[0], probs.numpy())\n", - " print(flag)\n", - " \n", - " return probs\n", - "\n", - "# result_transcript = model.decode(\n", - "# audio,\n", - "# audio_len,\n", - "# vocab_list=dataset.vocab_list,\n", - "# decoding_method=config.decoding.decoding_method,\n", - "# lang_model_path=config.decoding.lang_model_path,\n", - "# beam_alpha=config.decoding.alpha,\n", - "# beam_beta=config.decoding.beta,\n", - "# beam_size=config.decoding.beam_size,\n", - "# cutoff_prob=config.decoding.cutoff_prob,\n", - "# cutoff_top_n=config.decoding.cutoff_top_n,\n", - "# num_processes=config.decoding.num_proc_bsearch)\n", - "# return result_transcript[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Warm-up Test Case %d: %s 0 /home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n", - "/home/ssd5/zhanghui/DeepSpeech2.x/examples/aishell/../dataset/aishell/data_aishell/wav/test/S0764/BAC009S0764W0124.wav\n", - "input: 0 audio\n", - "input: 1 audio_len\n", - "output: 0 tmp_75\n", - "jit: [[[8.91786298e-12 4.45648032e-12 3.67572750e-09 ... 8.91767563e-12\n", - " 8.91573707e-12 4.64317296e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638127e-17 7.61802427e-16 2.93265812e-14 ... 1.24633371e-17\n", - " 1.24587264e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676260e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89334696e-13 1.66754856e-11 1.42900388e-11 ... 3.89329492e-13\n", - " 3.89252270e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]] \n", - "[1, 161, 522]\n", - "[1]\n", - "paddle: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n", - " 8.91577090e-12 4.64319072e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n", - " 1.24587735e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n", - " 3.89253761e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]]\n", - "False\n" - ] - } - ], - "source": [ - "manifest = read_manifest(args.warmup_manifest)\n", - "\n", - "for idx, sample in enumerate(manifest[:1]):\n", - " print(\"Warm-up Test Case %d: %s\", idx, sample['audio_filepath'])\n", - " start_time = time.time()\n", - " transcript = file_to_transcript(sample['audio_filepath'])\n", - " finish_time = time.time()\n", - "# print(\"Response Time: %f, Transcript: %s\" %\n", - "# (finish_time - start_time, transcript))\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 161, 522) (1,)\n", - "input: 0 audio\n", - "input: 1 audio_len\n", - "output: 0 tmp_75\n", - "jit: [[[8.91789680e-12 4.45649724e-12 3.67574149e-09 ... 8.91770945e-12\n", - " 8.91577090e-12 4.64319072e-08]\n", - " [1.55950222e-15 2.62794089e-14 4.50423509e-12 ... 1.55944271e-15\n", - " 1.55891342e-15 9.99992609e-01]\n", - " [1.24638599e-17 7.61805339e-16 2.93267472e-14 ... 1.24633842e-17\n", - " 1.24587735e-17 1.00000000e+00]\n", - " ...\n", - " [4.37488240e-15 2.43676737e-12 1.98770514e-12 ... 4.37479896e-15\n", - " 4.37354747e-15 1.00000000e+00]\n", - " [3.89336187e-13 1.66755481e-11 1.42900925e-11 ... 3.89330983e-13\n", - " 3.89253761e-13 1.00000000e+00]\n", - " [1.00349985e-10 2.56293708e-10 2.91177582e-10 ... 1.00347876e-10\n", - " 1.00334095e-10 9.99998808e-01]]]\n" - ] - } - ], - "source": [ - "def test(filename):\n", - " feature = dataset.process_utterance(filename, \"\")\n", - " audio = np.array([feature[0]]).astype('float32') #[1, D, T]\n", - " audio_len = feature[0].shape[1]\n", - " audio_len = np.array([audio_len]).astype('int64') # [1]\n", - " \n", - " print(audio.shape, audio_len.shape)\n", - "\n", - " i_probs = run(predictor, audio, audio_len)\n", - " print('jit:', i_probs[0])\n", - " return i_probs\n", - " \n", - "probs = test(sample['audio_filepath'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/.notebook/layer_norm_test.ipynb b/.notebook/layer_norm_test.ipynb deleted file mode 100644 index eac3566ff..000000000 --- a/.notebook/layer_norm_test.ipynb +++ /dev/null @@ -1,229 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 32, - "id": "academic-surname", - "metadata": {}, - "outputs": [], - "source": [ - "import paddle\n", - "from paddle import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "fundamental-treasure", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])\n", - "Parameter containing:\n", - "Tensor(shape=[256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])\n" - ] - } - ], - "source": [ - "L = nn.LayerNorm(256, epsilon=1e-12)\n", - "for p in L.parameters():\n", - " print(p)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "consolidated-elephant", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "moderate-noise", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n" - ] - } - ], - "source": [ - "x = np.random.randn(2, 51, 256)\n", - "print(x.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "cooked-progressive", - "metadata": {}, - "outputs": [], - "source": [ - "y = L(paddle.to_tensor(x, dtype='float32'))" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "optimum-milwaukee", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "viral-indian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parameter containing:\n", - "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1.], requires_grad=True)\n", - "Parameter containing:\n", - "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " requires_grad=True)\n" - ] - } - ], - "source": [ - "TL = torch.nn.LayerNorm(256, eps=1e-12)\n", - "for p in TL.parameters():\n", - " print(p)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "skilled-vietnamese", - "metadata": {}, - "outputs": [], - "source": [ - "ty = TL(torch.tensor(x, dtype=torch.float32))" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "incorrect-allah", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(y.numpy(), ty.detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "prostate-cameroon", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 52, - "id": "governmental-surge", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = np.random.randn(2, 256)\n", - "y = L(paddle.to_tensor(x, dtype='float32'))\n", - "ty = TL(torch.tensor(x, dtype=torch.float32))\n", - "np.allclose(y.numpy(), ty.detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "confidential-jacket", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/mask_and_masked_fill_test.ipynb b/.notebook/mask_and_masked_fill_test.ipynb deleted file mode 100644 index 265ec536b..000000000 --- a/.notebook/mask_and_masked_fill_test.ipynb +++ /dev/null @@ -1,449 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "primary-organic", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "stopped-semester", - "metadata": {}, - "outputs": [], - "source": [ - "def mask_finished_scores(score: torch.Tensor,\n", - " flag: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " If a sequence is finished, we only allow one alive branch. This function\n", - " aims to give one branch a zero score and the rest -inf score.\n", - " Args:\n", - " score (torch.Tensor): A real value array with shape\n", - " (batch_size * beam_size, beam_size).\n", - " flag (torch.Tensor): A bool array with shape\n", - " (batch_size * beam_size, 1).\n", - " Returns:\n", - " torch.Tensor: (batch_size * beam_size, beam_size).\n", - " \"\"\"\n", - " beam_size = score.size(-1)\n", - " zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n", - " if beam_size > 1:\n", - " unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n", - " dim=1)\n", - " finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n", - " dim=1)\n", - " else:\n", - " unfinished = zero_mask\n", - " finished = flag\n", - " print(unfinished)\n", - " print(finished)\n", - " score.masked_fill_(unfinished, -float('inf'))\n", - " score.masked_fill_(finished, 0)\n", - " return score" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "agreed-portuguese", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ True],\n", - " [False]])\n", - "tensor([[-0.8841, 0.7381, -0.9986],\n", - " [ 0.2675, -0.7971, 0.3798]])\n", - "tensor([[ True, True],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "score = torch.randn((2, 3))\n", - "flag = torch.ones((2, 1), dtype=torch.bool)\n", - "flag[1] = False\n", - "print(flag)\n", - "print(score)\n", - "print(flag.repeat([1, 2]))" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "clean-aspect", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[False, True, True],\n", - " [False, False, False]])\n", - "tensor([[ True, False, False],\n", - " [False, False, False]])\n", - "tensor([[ 0.0000, -inf, -inf],\n", - " [ 0.2675, -0.7971, 0.3798]])\n", - "tensor([[ 0.0000, -inf, -inf],\n", - " [ 0.2675, -0.7971, 0.3798]])\n" - ] - } - ], - "source": [ - "r = mask_finished_scores(score, flag)\n", - "print(r)\n", - "print(score)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "thrown-airline", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True ],\n", - " [False]])\n", - "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "import paddle\n", - "\n", - "score = paddle.randn((2, 3))\n", - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag)\n", - "print(score)\n", - "print(flag.tile([1, 2]))" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "internal-patent", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[False, True , True ],\n", - " [False, False, False]])\n", - "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , False, False],\n", - " [False, False, False]])\n", - "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, 1.87704289, 0.01988174],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 2.05994511, -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 0. , -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n", - "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[ 0. , -inf. , -inf. ],\n", - " [-0.40165186, 0.77547729, -0.64469045]])\n" - ] - } - ], - "source": [ - "paddle.bool = 'bool'\n", - "\n", - "def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", - " print(xs)\n", - " trues = paddle.ones_like(xs) * value\n", - " assert xs.shape == mask.shape\n", - " xs = paddle.where(mask, trues, xs)\n", - " return xs\n", - "\n", - "def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n", - " print('x', xs)\n", - " trues = paddle.ones_like(xs) * value\n", - " assert xs.shape == mask.shape\n", - " ret = paddle.where(mask, trues, xs)\n", - " print('2', xs)\n", - " paddle.assign(ret, output=xs)\n", - " print('3', xs)\n", - "\n", - "paddle.Tensor.masked_fill = masked_fill\n", - "paddle.Tensor.masked_fill_ = masked_fill_\n", - "\n", - "def mask_finished_scores_pd(score: paddle.Tensor,\n", - " flag: paddle.Tensor) -> paddle.Tensor:\n", - " \"\"\"\n", - " If a sequence is finished, we only allow one alive branch. This function\n", - " aims to give one branch a zero score and the rest -inf score.\n", - " Args:\n", - " score (torch.Tensor): A real value array with shape\n", - " (batch_size * beam_size, beam_size).\n", - " flag (torch.Tensor): A bool array with shape\n", - " (batch_size * beam_size, 1).\n", - " Returns:\n", - " torch.Tensor: (batch_size * beam_size, beam_size).\n", - " \"\"\"\n", - " beam_size = score.shape[-1]\n", - " zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n", - " if beam_size > 1:\n", - " unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n", - " axis=1)\n", - " finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n", - " axis=1)\n", - " else:\n", - " unfinished = zero_mask\n", - " finished = flag\n", - " print(unfinished)\n", - " print(finished)\n", - " \n", - " #score.masked_fill_(unfinished, -float('inf'))\n", - " #score.masked_fill_(finished, 0)\n", - "# infs = paddle.ones_like(score) * -float('inf')\n", - "# score = paddle.where(unfinished, infs, score)\n", - "# score = paddle.where(finished, paddle.zeros_like(score), score)\n", - "\n", - "# score = score.masked_fill(unfinished, -float('inf'))\n", - "# score = score.masked_fill(finished, 0)\n", - " score.masked_fill_(unfinished, -float('inf'))\n", - " score.masked_fill_(finished, 0)\n", - " return score\n", - "\n", - "r = mask_finished_scores_pd(score, flag)\n", - "print(r)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "vocal-prime", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 57, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "score.value" - ] - }, - { - "cell_type": "code", - "execution_count": 71, - "id": "bacterial-adolescent", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Union, Any" - ] - }, - { - "cell_type": "code", - "execution_count": 72, - "id": "absent-fiber", - "metadata": {}, - "outputs": [], - "source": [ - "def repeat(xs : paddle.Tensor, *size: Any):\n", - " print(size)\n", - " return paddle.tile(xs, size)\n", - "paddle.Tensor.repeat = repeat" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "id": "material-harbor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag.repeat(1, 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "id": "acute-brighton", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [1]), 2)\n", - "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[True , True ],\n", - " [False, False]])\n" - ] - } - ], - "source": [ - "flag = paddle.ones((2, 1), dtype='bool')\n", - "flag[1] = False\n", - "print(flag.repeat(paddle.to_tensor(1), 2))" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "id": "european-rugby", - "metadata": {}, - "outputs": [], - "source": [ - "def size(xs, *args: int):\n", - " nargs = len(args)\n", - " s = paddle.shape(xs)\n", - " assert(nargs <= 1)\n", - " if nargs == 1:\n", - " return s[args[0]]\n", - " else:\n", - " return s\n", - "paddle.Tensor.size = size" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "id": "moral-special", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [2, 1])" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size()" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "id": "ahead-coach", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [1])" - ] - }, - "execution_count": 87, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "id": "incomplete-fitness", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n", - " [2])" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "flag.size(0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "upset-connectivity", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/position_embeding_check.ipynb b/.notebook/position_embeding_check.ipynb deleted file mode 100644 index d4b9098d9..000000000 --- a/.notebook/position_embeding_check.ipynb +++ /dev/null @@ -1,231 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "id": "designing-borough", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "import torch\n", - "import math\n", - "import numpy as np\n", - "\n", - "max_len=100\n", - "d_model=256\n", - "\n", - "pe = torch.zeros(max_len, d_model)\n", - "position = torch.arange(0, max_len,\n", - " dtype=torch.float32).unsqueeze(1)\n", - "toruch_position = position\n", - "div_term = torch.exp(\n", - " torch.arange(0, d_model, 2, dtype=torch.float32) *\n", - " -(math.log(10000.0) / d_model))\n", - "tourch_div_term = div_term.cpu().detach().numpy()\n", - "\n", - "\n", - "\n", - "torhc_sin = torch.sin(position * div_term)\n", - "torhc_cos = torch.cos(position * div_term)\n", - "print(torhc_sin.cpu().detach().numpy())\n", - "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", - "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", - "print(np.allclose(np_sin, torhc_sin.cpu().detach().numpy()))\n", - "print(np.allclose(np_cos, torhc_cos.cpu().detach().numpy()))\n", - "pe[:, 0::2] = torhc_sin\n", - "pe[:, 1::2] = torhc_cos\n", - "tourch_pe = pe.cpu().detach().numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "swiss-referral", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "False\n", - "False\n", - "False\n", - "False\n", - "[[ 1. 1. 1. ... 1. 1.\n", - " 1. ]\n", - " [ 0.5403023 0.59737533 0.6479059 ... 1. 1.\n", - " 1. ]\n", - " [-0.41614684 -0.28628543 -0.1604359 ... 0.99999994 1.\n", - " 1. ]\n", - " ...\n", - " [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.99993724\n", - " 0.9999457 ]\n", - " [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n", - " 0.99994457]\n", - " [ 0.03982088 -0.52298605 -0.6157435 ... 0.99992454 0.9999347\n", - " 0.99994344]]\n", - "----\n", - "[[ 1. 1. 1. ... 1. 1.\n", - " 1. ]\n", - " [ 0.54030234 0.59737533 0.6479059 ... 1. 1.\n", - " 1. ]\n", - " [-0.41614684 -0.28628543 -0.1604359 ... 1. 1.\n", - " 1. ]\n", - " ...\n", - " [-0.92514753 -0.66694194 -0.67894876 ... 0.9999276 0.9999373\n", - " 0.9999457 ]\n", - " [-0.81928825 -0.9959641 -0.999139 ... 0.99992603 0.999936\n", - " 0.99994457]\n", - " [ 0.03982088 -0.5229861 -0.6157435 ... 0.99992454 0.9999347\n", - " 0.99994344]]\n", - ")))))))\n", - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n", - "----\n", - "[[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00\n", - " 0.0000000e+00 0.0000000e+00]\n", - " [ 8.4147096e-01 8.0196178e-01 7.6172036e-01 ... 1.2409373e-04\n", - " 1.1547816e-04 1.0746076e-04]\n", - " [ 9.0929741e-01 9.5814437e-01 9.8704624e-01 ... 2.4818745e-04\n", - " 2.3095631e-04 2.1492151e-04]\n", - " ...\n", - " [ 3.7960774e-01 7.4510968e-01 7.3418564e-01 ... 1.2036801e-02\n", - " 1.1201146e-02 1.0423505e-02]\n", - " [-5.7338190e-01 -8.9752287e-02 -4.1488394e-02 ... 1.2160885e-02\n", - " 1.1316618e-02 1.0530960e-02]\n", - " [-9.9920684e-01 -8.5234123e-01 -7.8794664e-01 ... 1.2284970e-02\n", - " 1.1432089e-02 1.0638415e-02]]\n" - ] - } - ], - "source": [ - "import paddle\n", - "paddle.set_device('cpu')\n", - "ppe = paddle.zeros((max_len, d_model), dtype='float32')\n", - "position = paddle.arange(0, max_len,\n", - " dtype='float32').unsqueeze(1)\n", - "print(np.allclose(position.numpy(), toruch_position))\n", - "div_term = paddle.exp(\n", - " paddle.arange(0, d_model, 2, dtype='float32') *\n", - " -(math.log(10000.0) / d_model))\n", - "print(np.allclose(div_term.numpy(), tourch_div_term))\n", - "\n", - "\n", - "\n", - "p_sin = paddle.sin(position * div_term)\n", - "p_cos = paddle.cos(position * div_term)\n", - "print(np.allclose(np_sin, p_sin.numpy(), rtol=1.e-6, atol=0))\n", - "print(np.allclose(np_cos, p_cos.numpy(), rtol=1.e-6, atol=0))\n", - "ppe[:, 0::2] = p_sin\n", - "ppe[:, 1::2] = p_cos\n", - "print(np.allclose(p_sin.numpy(), torhc_sin.cpu().detach().numpy()))\n", - "print(np.allclose(p_cos.numpy(), torhc_cos.cpu().detach().numpy()))\n", - "print(p_cos.numpy())\n", - "print(\"----\")\n", - "print(torhc_cos.cpu().detach().numpy())\n", - "print(\")))))))\")\n", - "print(p_sin.numpy())\n", - "print(\"----\")\n", - "print(torhc_sin.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "integrated-boards", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n" - ] - } - ], - "source": [ - "print(np.allclose(ppe.numpy(), pe.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "flying-reserve", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "revised-divide", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/python_test.ipynb b/.notebook/python_test.ipynb deleted file mode 100644 index 819d4c48f..000000000 --- a/.notebook/python_test.ipynb +++ /dev/null @@ -1,1680 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-lender", - "metadata": {}, - "outputs": [], - "source": [ - "eng=\"one minute a voice said and the time buzzer sounded\"\n", - "chn=\"可控是病毒武器最基本的要求\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ruled-kuwait", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "o\n", - "n\n", - "e\n", - " \n", - "m\n", - "i\n", - "n\n", - "u\n", - "t\n", - "e\n", - " \n", - "a\n", - " \n", - "v\n", - "o\n", - "i\n", - "c\n", - "e\n", - " \n", - "s\n", - "a\n", - "i\n", - "d\n", - " \n", - "a\n", - "n\n", - "d\n", - " \n", - "t\n", - "h\n", - "e\n", - " \n", - "t\n", - "i\n", - "m\n", - "e\n", - " \n", - "b\n", - "u\n", - "z\n", - "z\n", - "e\n", - "r\n", - " \n", - "s\n", - "o\n", - "u\n", - "n\n", - "d\n", - "e\n", - "d\n" - ] - } - ], - "source": [ - "for char in eng:\n", - " print(char)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "passive-petite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "可\n", - "控\n", - "是\n", - "病\n", - "毒\n", - "武\n", - "器\n", - "最\n", - "基\n", - "本\n", - "的\n", - "要\n", - "求\n" - ] - } - ], - "source": [ - "for char in chn:\n", - " print(char)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "olympic-realtor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "one\n", - "minute\n", - "a\n", - "voice\n", - "said\n", - "and\n", - "the\n", - "time\n", - "buzzer\n", - "sounded\n" - ] - } - ], - "source": [ - "for word in eng.split():\n", - " print(word)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "induced-enhancement", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "可控是病毒武器最基本的要求\n" - ] - } - ], - "source": [ - "for word in chn.split():\n", - " print(word)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "lovely-bottle", - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'StringIO'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mStringIO\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'StringIO'" - ] - } - ], - "source": [ - "import StringIO" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "interested-cardiff", - "metadata": {}, - "outputs": [], - "source": [ - "from io import StringIO" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "portable-ivory", - "metadata": {}, - "outputs": [], - "source": [ - "inputs = StringIO()" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "compatible-destination", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "federal-margin", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n", - "\n" - ] - } - ], - "source": [ - "print(inputs.getvalue())" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "consecutive-entity", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "64" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "inputs.write(\"nor is mister quilter's manner less interesting than his matter\" + '\\n')" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "desirable-anxiety", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nor is mister quilter's manner less interesting than his matternor is mister quilter's manner less interesting than his matter\n", - "nor is mister quilter's manner less interesting than his matter\n", - "\n" - ] - } - ], - "source": [ - "print(inputs.getvalue())" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "employed-schedule", - "metadata": {}, - "outputs": [], - "source": [ - "import tempfile" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "unlikely-honduras", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__class__', '__del__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__ne__', '__new__', '__next__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_checkClosed', '_checkReadable', '_checkSeekable', '_checkWritable', '_dealloc_warn', '_finalizing', 'close', 'closed', 'detach', 'fileno', 'flush', 'isatty', 'mode', 'name', 'peek', 'raw', 'read', 'read1', 'readable', 'readinto', 'readinto1', 'readline', 'readlines', 'seek', 'seekable', 'tell', 'truncate', 'writable', 'write', 'writelines']\n", - "57\n" - ] - } - ], - "source": [ - "with tempfile.TemporaryFile() as fp:\n", - " print(dir(fp))\n", - " print(fp.name)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "needed-trail", - "metadata": {}, - "outputs": [], - "source": [ - "a = tempfile.mkstemp(suffix=None, prefix='test', dir=None, text=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "hazardous-choir", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__add__', '__class__', '__contains__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'count', 'index']\n" - ] - } - ], - "source": [ - "print(dir(a))" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "front-sauce", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(57, '/tmp/test27smzbzc')\n" - ] - } - ], - "source": [ - "print(a)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "shared-wages", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "print(a.index)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "charged-carnival", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_closer', 'close', 'delete', 'file', 'name']\n", - "/tmp/tmpfjn7mygy\n" - ] - } - ], - "source": [ - "fp= tempfile.NamedTemporaryFile(mode='w', delete=False)\n", - "print(dir(fp))\n", - "print(fp.name)\n", - "fp.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "religious-terror", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/tmp/tmpfjn7mygy\n" - ] - } - ], - "source": [ - "import os\n", - "os.path.exists(fp.name)\n", - "print(fp.name)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "communist-gospel", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fp.write" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "simplified-clarity", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'example'" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s='/home/ubuntu/python/example.py'\n", - "os.path.splitext(os.path.basename(s))[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "popular-genius", - "metadata": {}, - "outputs": [], - "source": [ - "from collections import Counter" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "studied-burner", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('hello', 1), ('world', 1)])\n" - ] - } - ], - "source": [ - "counter = Counter()\n", - "counter.update([\"hello\"])\n", - "counter.update([\"world\"])\n", - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "mineral-ceremony", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('h', 1), ('e', 1), ('l', 3), ('o', 2), ('w', 1), ('r', 1), ('d', 1)])\n" - ] - } - ], - "source": [ - "counter = Counter()\n", - "counter.update(\"hello\")\n", - "counter.update(\"world\")\n", - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "nonprofit-freedom", - "metadata": {}, - "outputs": [], - "source": [ - "counter.update(list(\"hello\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "extended-methodology", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dict_items([('h', 2), ('e', 2), ('l', 5), ('o', 3), ('w', 1), ('r', 1), ('d', 1)])\n" - ] - } - ], - "source": [ - "print(counter.items())" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "grand-benjamin", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['h', 'e', 'l', 'l', 'o']" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "list(\"hello\")" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "id": "marine-fundamentals", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{}\n" - ] - } - ], - "source": [ - "from io import StringIO\n", - "a = StringIO(initial_value='{}', newline='')\n", - "print(a.read())" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "suitable-charlotte", - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "expected str, bytes or os.PathLike object, not _io.StringIO", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: expected str, bytes or os.PathLike object, not _io.StringIO" - ] - } - ], - "source": [ - "with io.open(a) as f:\n", - " print(f.read())" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "institutional-configuration", - "metadata": {}, - "outputs": [], - "source": [ - "io.open?" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "pregnant-modem", - "metadata": {}, - "outputs": [], - "source": [ - "def get_default_args(fn):\n", - " if fn is None:\n", - " return {}\n", - "\n", - " signature = inspect.signature(fn)\n", - " return {\n", - " k: v.default\n", - " for k, v in signature.parameters.items()\n", - " if v.default is not inspect.Parameter.empty\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "id": "first-release", - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'inspect' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_default_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mget_default_args\u001b[0;34m(fn)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m return {\n\u001b[1;32m 7\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'inspect' is not defined" - ] - } - ], - "source": [ - "get_default_args(io.open)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "convertible-roulette", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: sox in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (1.4.1)\n", - "Requirement already satisfied: numpy>=1.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from sox) (1.20.1)\n", - "Requirement already satisfied: librosa in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (0.8.0)\n", - "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.24.1)\n", - "Requirement already satisfied: numba>=0.43.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.52.0)\n", - "Requirement already satisfied: pooch>=1.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.3.0)\n", - "Requirement already satisfied: scipy>=1.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.2.1)\n", - "Requirement already satisfied: numpy>=1.15.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.20.1)\n", - "Requirement already satisfied: decorator>=3.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (4.4.2)\n", - "Requirement already satisfied: resampy>=0.2.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.2.2)\n", - "Requirement already satisfied: audioread>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (2.1.9)\n", - "Requirement already satisfied: soundfile>=0.9.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (0.9.0.post1)\n", - "Requirement already satisfied: joblib>=0.14 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from librosa) (1.0.1)\n", - "Requirement already satisfied: setuptools in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (51.0.0)\n", - "Requirement already satisfied: llvmlite<0.36,>=0.35.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from numba>=0.43.0->librosa) (0.35.0)\n", - "Requirement already satisfied: appdirs in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (1.4.4)\n", - "Requirement already satisfied: packaging in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (20.9)\n", - "Requirement already satisfied: requests in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from pooch>=1.0->librosa) (2.25.1)\n", - "Requirement already satisfied: six>=1.3 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from resampy>=0.2.2->librosa) (1.15.0)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (2.1.0)\n", - "Requirement already satisfied: cffi>=0.6 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from soundfile>=0.9.0->librosa) (1.14.4)\n", - "Requirement already satisfied: pycparser in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cffi>=0.6->soundfile>=0.9.0->librosa) (2.20)\n", - "Requirement already satisfied: pyparsing>=2.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from packaging->pooch>=1.0->librosa) (2.4.7)\n", - "Requirement already satisfied: idna<3,>=2.5 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2.10)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (2020.12.5)\n", - "Requirement already satisfied: chardet<5,>=3.0.2 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (4.0.0)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from requests->pooch>=1.0->librosa) (1.26.3)\n" - ] - } - ], - "source": [ - "!pip install sox\n", - "!pip install librosa" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "cutting-fleece", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import sox\n", - "tfm = sox.Transformer()\n", - "sample_rate = 44100\n", - "y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)\n", - "print(y.dtype.type)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "historical-diving", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.06264832 0.12505052 ... -0.18696144 -0.12505052\n", - " -0.06264832]\n" - ] - } - ], - "source": [ - "output_array = tfm.build_array(input_array=y, sample_rate_in=sample_rate)\n", - "print(output_array)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "similar-spice", - "metadata": {}, - "outputs": [], - "source": [ - "tfm.build_array?" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "grand-influence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['8svx', 'aif', 'aifc', 'aiff', 'aiffc', 'al', 'amb', 'amr-nb', 'amr-wb', 'anb', 'au', 'avr', 'awb', 'caf', 'cdda', 'cdr', 'cvs', 'cvsd', 'cvu', 'dat', 'dvms', 'f32', 'f4', 'f64', 'f8', 'fap', 'flac', 'fssd', 'gsm', 'gsrt', 'hcom', 'htk', 'ima', 'ircam', 'la', 'lpc', 'lpc10', 'lu', 'mat', 'mat4', 'mat5', 'maud', 'nist', 'ogg', 'paf', 'prc', 'pvf', 'raw', 's1', 's16', 's2', 's24', 's3', 's32', 's4', 's8', 'sb', 'sd2', 'sds', 'sf', 'sl', 'sln', 'smp', 'snd', 'sndfile', 'sndr', 'sndt', 'sou', 'sox', 'sph', 'sw', 'txw', 'u1', 'u16', 'u2', 'u24', 'u3', 'u32', 'u4', 'u8', 'ub', 'ul', 'uw', 'vms', 'voc', 'vorbis', 'vox', 'w64', 'wav', 'wavpcm', 'wv', 'wve', 'xa', 'xi']\n" - ] - } - ], - "source": [ - "print(sox.core._get_valid_formats())" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "wireless-hypothetical", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float64\n", - "(59471,)\n", - "16000\n", - "(54065,)\n", - "1.0999907518727459\n" - ] - } - ], - "source": [ - "import soundfile as sf\n", - "wav='/workspace/DeepSpeech-2.x/examples/aishell/s1/../../..//examples/dataset/aishell/data_aishell/wav/dev/S0724/BAC009S0724W0190.wav'\n", - "samples, sr = sf.read(wav)\n", - "print(samples.dtype)\n", - "print(samples.shape)\n", - "print(sr)\n", - "tfm = sox.Transformer()\n", - "tfm.speed(1.1)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "output_array.dtype\n", - "print(output_array.shape)\n", - "print(len(samples)/len(output_array))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "designed-fluid", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import IPython.display as ipd\n", - "ipd.Audio(wav) # load a local WAV file" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "cultural-friendship", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(1.0)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "fossil-lotus", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(1.1)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "constitutional-poker", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfm = sox.Transformer()\n", - "tfm.speed(0.9)\n", - "output_array = tfm.build_array(input_array=samples, sample_rate_in=sr)\n", - "ipd.Audio(output_array, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "threaded-strap", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "66078\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8K0lEQVR4nO2dd3hUZfbHvycdQoAEQpEWmlQVJICoKApqABdcF8u6KlbUXX+77rqrIFZs7Lr2si5WXHXtrigI0myoSFB67xDpoYQEUs/vj7kTJpM7M/fO7XPP53nmye33ZObe97zveU8hZoYgCILgX5KcFkAQBEFwFlEEgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPscURUBEBUS0log2ENF4lf1/IaJVRLSMiOYSUYeQfWOJaL3yGWuGPIIgCIJ2yGgcARElA1gH4DwAOwAsAvBbZl4Vcsw5ABYycxkR3QJgCDNfRkQ5AAoB5ANgAIsB9GPmA4aEEgRBEDRjxohgAIANzLyJmSsAvANgdOgBzDyfmcuU1R8AtFWWLwAwm5mLlcZ/NoACE2QSBEEQNJJiwjXaANgesr4DwMAox18P4PMo57aJdcPmzZtzXl6ePikFQRB8zuLFi/cxc274djMUgWaI6EoEzEBnx3HuOADjAKB9+/YoLCw0WTpBEITEhoi2qm03wzRUBKBdyHpbZVu4AMMATAQwipnL9ZwLAMw8hZnzmTk/N7eeQhMEQRDixAxFsAhAVyLqSERpAC4HMC30ACLqC+DfCCiBPSG7ZgE4n4iyiSgbwPnKNkEQBMEmDJuGmLmKiG5FoAFPBvAqM68kokkACpl5GoDHADQC8D4RAcA2Zh7FzMVE9CACygQAJjFzsVGZBEEQBO0Ydh91gvz8fJY5AkEQBH0Q0WJmzg/fLpHFgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgif42/tLsXirpKASBCuwNbJYEOLl/cU7kJ6ahH4dsp0WRRASDhkRCIKQcKwoOoSZK3Y5LYZnEEUgCELCMeGj5bj5zcVOi+EZRBEInsGDsY+CQzDkYdGDKALBM8irLWhFOg36EEUgCELCIYpAH6IIBEFIOEQP6EMUgeAZ3l64zWkRBI/gxWSaTiKKQBAEweeIIhAEIaGorK7Bml0lTovhKUQRCIKQUBworXBaBM8hikAQhISiRqYHdCOKQPAUFVU1OPHuz50WQ3AxNTJRrBtTFAERFRDRWiLaQETjVfafRUQ/EVEVEY0J21dNREuUz7TwcwUhlKMV1aioqnFaDMHFiCLQj+Hso0SUDOB5AOcB2AFgERFNY+ZVIYdtA3ANgL+qXOIoM/cxKofgE8hpAQS3UyP9BN2YMSIYAGADM29i5goA7wAYHXoAM29h5mUA5CcSdDFj+U7kjZ/utBiCh5ARgX7MUARtAGwPWd+hbNNKBhEVEtEPRHSRCfIkNHnjp2P/kXKnxbCNtWFugEkyIhBiMPKZb5wWwXO4YbK4AzPnA7gCwFNE1FntICIapyiMwr1799orocvYU+IfRRDetyMSTSBEp7SiunZ5n486TUYwQxEUAWgXst5W2aYJZi5S/m4C8CWAvhGOm8LM+cycn5ubG7+0HmD+mj1Ysv1gve3HKgMPeLWf/ONkmC8YYM9hUQRaMEMRLALQlYg6ElEagMsBaPL+IaJsIkpXlpsDOAPAquhnJT7Xvr4If3jrp3rb/zFzLQB/KYLw/3Tu6t2OyCF4E5kv0IZhRcDMVQBuBTALwGoA7zHzSiKaRESjAICI+hPRDgCXAPg3Ea1UTu8BoJCIlgKYD2BymLeRr8kbP73OROn+0kDv5kCZfyMn//TOEgCSVEzQhjwm2jCleD0zzwAwI2zbvSHLixAwGYWf9x2Ak8yQIdFQM4UHH+oHPl2FId1a2CuQQ0R6kcurapCRmmyvMB6kpoZRw4yUZDdMB9qPjAi04c+nw8UcVHr7as9vsZJDZfO+UjtFcpRIJQeHPfGVzZJ4k7/PWoPu98x0WgzH+GHTfqdF8ASiCFzG8qJDqts37DmCbzfss1ka54nUodtx4Ki9gniUVb8cRpWP5pQ+WVLXT0UUgTZEEXiEV77d7LQIjuCfJsxanGgQDx+rxLIdB2295+qddeNOkpOkidOCfEseoVri5gUDXD7lB9vv+cQX6zDquQW23zcUn06N6Ea+Jo/gp+F9KDLXZ4yFm4odu3eVCzovBb1bOS2CJxBF4BFq/KoIxDhkiIpq5xrjJAeiwMOfl7Rk8SzTgigCl1J0sO5kaLWG9rDkWCXKKqosksghRA94FicUgRAfoghcRiRTiJYRwbAnvnLEFmwloge8i+gB72BKQJlgPdOX74x5zO7D5ThyLMFGBIJnkRGBd5ARgcsw+u5UJ9jsqqSS8C4rIsTEWMl7i7bHPkiohygCB2FmVJk8mZdoCelED3iTWSt3YeFm+z2WDpRV1lkXZwNtiCJwkBe+3IguE80txJ5oikDwJou3HqhdPmPyPAclEbQgisBBPltW3+4f+gLFQ6LpgQT7d3xDqIUz3ANOcB8yWewgq3cerrftqTnrHZBEEMzjdy//gAUb3JHjp0hyUmlCRgSCq5E5Am9xoLTCNUoAAB79fA0A4L8/bkNpuXjURUIUgUuoqq7Bawu0J5Yrr6qus/7yN5vMFsn1/PX9pfhmvb/rVwPA7sPHsOvQMafFAAAcrayOfZADTPhoOeau2eO0GK5FFIFL2LK/FA98qr0427pdR+qsPzR9tdkiuZ4PFu/AB4t3OC2G4/z6+QU4+7H5TosRFSkx6m5EEbgE0hlA8Kvnvq1dTuSc6+L+F5tjVTUor6rvhjxzxS7bZYn0GF8/tdBeQQRdmKIIiKiAiNYS0QYiGq+y/ywi+omIqohoTNi+sUS0XvmMNUMerzH6uW9hJI7s2tcWmSaL23htwZaYx0xfthMb9hyJeZzfePyLtbbfkww9yYJTGFYERJQM4HkAwwH0BPBbIuoZdtg2ANcAeDvs3BwA9wEYCGAAgPuIKNuoTF5j6Y5DukcEwPGo20oHM0y6gT+8/RMeneE/01iQSHmoJMODoBUzRgQDAGxg5k3MXAHgHQCjQw9g5i3MvAxAeIt1AYDZzFzMzAcAzAZQYIJMniPaOxvpRT+s5BXya60C4Pj35tdGr7i0AgePVsY+UBCiYIYiaAMgNMHHDmWb1ecmFNESdHW6a4b6DpZcPMf/e/9ogsv+/T0qlDmBaC6RTjwaTipkv9bsMAPPTBYT0TgiKiSiwr17ve8yOPbVH+usx/MCMRiLthiLRE4U9h4pd1oES7lhaiEOllUAABZuLsbhY4FRQFKSuxSgk9IcijAyyhs/3WZJvIcZiqAIQLuQ9bbKNlPPZeYpzJzPzPm5ublxCeomvlpnXJkxu6McoJN8suQXAMDS7QedFcRi5qzejeUh2TyDvX09De/KXw4ldKMo44H4MUMRLALQlYg6ElEagMsBTNN47iwA5xNRtjJJfL6yzXfENyIQL40gLusYW8LRiuPBWkG3Wj05/x/6rO6E+okTP8cvCZQH6JVvowdV+t2MGg3DioCZqwDcikADvhrAe8y8kogmEdEoACCi/kS0A8AlAP5NRCuVc4sBPIiAMlkEYJKyzXfE6zXk10nScOL5/rwGA9hbUn58BdE7EOH7UlPqvu4V1TXYsr/UPAEd5sWv/BddbxamzBEw8wxmPpGZOzPzw8q2e5l5mrK8iJnbMnMmMzdj5l4h577KzF2Uz2tmyONF4mnGGJKLx2/0f3gOAODCZwMBhXqem68Vc+ScVbstm1j9YlXkCGIr3ZwPllVICnYDeGayONFRy0QaC2ZgW3H9Hp0f4wqqaxj7E3zCOJQ9JeXIGz8dO+PIMXTDG4VY8Ysy32By23n3/1ZE3Ldwk3WD/QofPvNmIorAJcQTgs9g3Pnh8nrbS3xat7jfQ3NwJIEzTKqN/lbF0YEArIk9eXZu9BTqKckWmu9kMGAIUQQeJpJZyM+eREcSUAlGGy2GNq16TCMXv/CdAYnUeXz2uqj7Uy1UBFpqdYsZNTKiCDxMpBffz7bSRExSN/zpbwAAN7+5uN6+X0JMQ+EeQJo8ymycY09Jsq65qapOvN/dTkQReJhIj77VL8XmfaU49/EvLb2HoI1nQswxySE+tC99vUlTL9lORj+/wLJra/lXE7GTYBaiCDxMJM+PvUfK6/icm83S7Qexaa873Q79Oj8C1B0JPjxjta8mz2tcpvS8higCBzArsGXwP9SLkVz8wnf44zs/m3IPNYI9q2U7Dlp2j3g5/8mvbb/n/iPl2HnI+cCscM8ZTSbCBGk/tfwbEnwZGVEEDvB+ofVVtfYctq50YVCPXfHSQsvu4SUum/IDBj06z2kx6plH/NRJ1jIiCDUN7Tl8DP+cZX+9BrciisABNu61vohKekqyZdcOvnNHyqtw+FhlbQI0vxJMBuc04SPNkjhcafcfKa/NbBqNPSXHcOGz3+i+vlXoVXqzVu7Cc/M3WCOMBxFF4AQ2jFDTUqz5af/15UY8FtKTuuj5BRj5tHsaBD9jxgCg30Nz8Pjs2D3ldbuOYEXRYcx3SUF4LebW0EN8NFjShCgCB7DDVmmFz/aew8fwxOy12BVidtq0txTbDzhvHw+lyvYoU3fYnoMN3UOfrTJ0nc0aHAGCHkrXvu6OMqnSsBtDFIEDvPjVRsvvkZps/k874JG5qPSAv/Zpj841fI1v1+/DnCh5c9xI0Ab+2bKdms+pZq430f39pv0xz7M0SjgONM0RuP/RdQxRBAlKsh/yMkdg3xFjNvtjldW48pWFuOENbWk/gpHcTuf6DwaU6/GXf+P7rXFNdLvt8dLSyIcqC1EKdRFFkKDYnZX5UFklikvdMWlqlEgpHf73cxG2F5fV2ZY3fjoOlrljsjyoAPQElu+LM9bAbQ2plhFBoVLNb8n2g1KbIAxRBIIpjHnxO5zzzy+dFsMUfh0hD89t7y7B8Ke/qW1EIjWi63eXOJLmI9i21dYs0EC8YrqtGdXSrr9bGCiPftHzC7B2t/Wee15CFEGCYnehlvV7jkSsGZtIHCmvqv0/Rz+nnjLhvCe/xmfLfjHlfnrNTXoVULw943g71Fa5TuuVp9rHiRnVEEVgM1p8tM3AZSZcz3Kssn6qji9W7cayHQexN4pZZXtxme3eSzXMuhv2eEcu8SqQZIs6KHrzCM1zidurW0hxWgC/cfZj6mkh3M6bP2x1WgRHmLbkF1zav12dbXd8sCzmef/8Yh3++UUgLfO8289Gp9xGlsgXCrN+k42aHtDiaBCvaSj82oFyq8aVQ2m5vtxaRh0KEg1TRgREVEBEa4loAxGNV9mfTkTvKvsXElGesj2PiI4S0RLl86IZ8riZeCpKxYOeouZaWLBhX8xj3ORSmDd+OraaUI+3vEpfA7NbJbXHLpt+83gaZ7XEhdecnhf7PBMmW//435/RccIMw9cBgN++9IMp1/ErhhUBESUDeB7AcAA9AfyWiHqGHXY9gAPM3AXAkwD+HrJvIzP3UT43G5VHCGD2CFzL9aqqGTsOlMU+0CYueMp4AroXv9qky3wy8JH6MQzlNpmInpu3Xr+tXOWEBqmx05METexNG6YC0G4qenjG6trlaUvNmUcRjGPGiGAAgA3MvImZKwC8A2B02DGjAUxVlj8AMJTsns0UbGFFUXylE63gWKXxBrjo4FHsKTHWo4+3ULxeO/yc1XtQVqEvv1BQyYUXtYl5niLbwbJArimtrsMzV+zSdR/BHsxQBG0AbA9Z36FsUz2GmasAHALQTNnXkYh+JqKviGiwCfIIMG+y+Ok565E3frpmO266RTmOnOSLlbt1N7ChBL+6eWt266oTce7jX+m+l16dEwyGGxXiAaXlpw5VbuWVNVEnztX4Zv3e2uUPF1ufjTfIT9sO2HYvL+H0W7sTQHtm7gvgLwDeJqLGagcS0TgiKiSiwr1796odIoRg1oDryTlKHVqNDcy/vrQ+fYaVbNtf37R137SVmLE8/p5sMLfUda8X4sOftDd6m/fpn+PQa7sPVrNT844KsmT7wXoeUKHmshpm3VXxrnrlx9rl299fqutcI1z1sqROV8MMRVAEINStoq2yTfUYIkoB0ATAfmYuZ+b9AMDMiwFsBHCi2k2YeQoz5zNzfm5urgli28/8tfa5rH38cxEmf77GtOuVa3R7/XFLsWn3NAO9LpwPfLpSdbtZIyyrA830KoKgPITjpqjwpIgXPb8A05fXzV8UepcaZs/UyRaLtDpmKIJFALoSUUciSgNwOYBpYcdMAzBWWR4DYB4zMxHlKpPNIKJOALoC2GSCTK5ES1ZHMzFzGOxESL4ZMRcHXJD+Yc2uktrl+6apKxqz0Psz7VGikEvKq2o7DkT1A79Ckw3OWbUbEz5aXrs+6NF5ltYjNpMjcdRo8AOGFYFi878VwCwAqwG8x8wriWgSEY1SDnsFQDMi2oCACSjoYnoWgGVEtASBSeSbmdldXUoPk2JiZjAn+nvvLtpm+BoPTzeWkjmIkY7k32fWHZlZ2Xs2cu3QrKNDH/8KRREmkN9auDXuHEWRsHO0LNTHlIAyZp4BYEbYtntDlo8BuETlvA8BfGiGDF7g5+0Hbb2fmcnQnJhkqzAh5fUeHXl3AHsUXg0zki2K/TaiCJbtOATguPkkkreTFeaVb9fvwzndWph+3Vh8sHgHxvRra/t93YZEFtvIpzb7Ta/epc2V81hlNVKSCClRahjoUSo7DpRh6ndbUF0D3Pur8JASe9HTZkWbMDXa9oUGmllpZRv13LeGrxEcSB4N+T5CzXRWqLDDDuWpWhMh06zfcNpryDfocRs0C2Zttv3THp2LP79X13PDSGTuut0leOXbzXh1wea4zg9i57TeoaOV6H7PTBRZVG0tNNDsCgujYM2YE3ng04A57faQZ+KRkEAwo7y1sH66khnLtRfTMROZOw4gisAmzAjJj4d/zFqLT5YURU1NfLCsEou3FGNEWO3hrSqulFogIlMK45jxkgZNJU/MXocNeyJnvgyaQdbuLlHdb2Z50cKt3vBlX150qHY5NysdAOpVM4uHiR+vqLet1IGOEmB+KhavIorAJpxSBKt3Hsaf3lmC/g/PiXrcL4eOYVXYMDleiaurzUkkZsYrGjRzPzN3PSZ8tCyiDf0HDeUZ7cSuLLWxCH5fF/RqBSDgITTXosydRmstx4O4kwYQRWATTqU//3KtvuC7UL/7eJVXZnpKxHTDI5/5RrOZzAzV2TgjtXZ50ZYDmLtavQ7xLW/9FPU6T89db4I0x4kVLOaW2g6/ejYw53CssrreiNFsXv7WmCkxHtxWctMpRBHYhFpyLzdSUV1TayapjtNr595PVkQ066z85TCKy7TlpTHjK5uzeje+XX88c2pVnF418UT5RiPWpG60iWs7CY4Sd6mMGBMBtyhcpxFFYBNujbwM9xXfsq+sdlu8jeb6PUdUba/BiWu7g9OufOV4WgG1uYvwOsR2EKuhj/e7F/Tx1sJtEWtU+wlRBDYRy0bvFGdMnldnfcQz39TWHq4wkD5ZLYIz2LZpNZO99p35poLMtPoe004M1oKms7zx01UnsZ3yoonEzJXHcy01zrDO6/z8J/Un2gO0FdOJxHCLTV5eQBSBDdz3SX0vCSdYExJXMH/tnohFV4K9UbMnLINzDlU1NZpGBduLzXflvOGNRfW2GVF48ZKcRLXRudtVajhMX+YuRRDK4WPWpWlYF2dReSdSoCQSoghsYOr37ijzOPW7LVinuEde+9oifLY0emNz1GQ7ddA8du7jX+EFh7KUqtUoqM2waiPJSVSraIN/3164Fa8qE6bSrGmnvKpad/ptoS6iCHzEf3/cjoKnvq5tkGOl/73nf8ZHMmUVVTh0tBJ546fjlAe+qN3+3LwNMc8deXJrw/ePRkVVDUqOVWKPSnlJqzl8rKr2d2jeKA3b9pfhro9XYJLiQtmmaYbtMrmFw8f0TeA+aILbad746Yav4WVEEVhMvNWprKKGgVIDRVb00vPeWfjzu0sA1E1lfUxDLeDcRumWyJQ3fjqKSysw/sNlOOn+L9CqSQNL7hOLexST4dpdR/C/JXUztw/u6s1U62Zw8v1f6HpvzPbo8iOSa8hiXvgyds/XbuzuAc9TCUA6s0tzW2UI59QHZ9cun9G5me15oIDjMR53fby83r5KB+Yt3ERVDSNNmQA+Ul6FBRv2ITcrHae2z6537Jqd6tHgejlUVokmDVNjH5iAiCKwmNW7zHlIzWTYE8aLuhulb/tsrN55GN1bZalGd57ywBe2+Xh/t9FdUcX//XEbSn2eNz9oNnt9wWbcr+Q+atIgFUvvO7/esfs11kuOxZ0fLsOLV/Uz5VpeQ0xDVuMuy5BrWLPzMIY//Q1mrdxdp3e+YMM+5I2fbmugzzQHRgPRmPDRcvy07aDTYjjK1uJSfL1ub60SUMNsZbm86BAu+/f3pl7TK4giEBzhi1WBVA83v7kYxaUVtUFdoYnO/IzfRwQFT32Dq1/9sc628FCBXvfNwnuLtqNBarIp9yw6eBQLNxdj0KOBTLHrd5dETdaYSIhpyGJ2OeCR4kUG/2M+crPSffPixcIrGUrtpKyiGoVbivHsvA14/NJTAAB3fLjM9PvsPHQMX6zchXH/WQwAmP3ns9C1ZRYAYPqyX3Bqh2y0dsjBwCrIi4EY+fn5XFhY6LQYmjj7H/Ox1YEUBoKQyJzYslHcwWfxsO6h4UhOInS+K1CIccvkkZbfc29JeW3678rqGqQmJ+FIeRUapcfffyeixcycH77dFNMQERUQ0Voi2kBE41X2pxPRu8r+hUSUF7JvgrJ9LRFdYIY8bkKUgCCYz+7D9o4c+z74Be6ftrLOttLyKqwwYMqsqWG8vXAbqmsYh8oq8dScdTikFBZiZvR/eA7W7y5BRVUNuk78HPPW7Ebv+2ZhT4n5VgbDioCIkgE8D2A4gJ4AfktE4fUJrwdwgJm7AHgSwN+Vc3sCuBxALwAFAF5QrqeL8qpqw+kQmFk1TL2sogoFT31de/2t+0s1ZYZk5noyScpbIRx5JOLD7qyhpeXV+M8PxzMElFdVo9d9s3ChkqZ7za7DWKc02sFU7rHaiU37juCuj5dj9c7DmPr9Fjw1Zz3+/fVG/LTtAK55LZAKZe+R8tpMtde9HrCC3PTGYny4eAfOffzLOterqKrBzBU7sWmv+kgp2ryTGXMEAwBsYOZNAEBE7wAYDSB0un80gPuV5Q8APEcBn8HRAN5h5nIAm4log3K9mFP3P24uxvbiMgzslIORz3wb0M4PXIAMZeLoaEU1UpPr1+FlDhRNYWbUMPDLwaNol9MQHSfMQIusdPw4cRiYGW8u3IaUJEJmegrW7CrBDW8U4ut1Ab/viSN6YGCnHMxcsQt/HNoV24vL0LVlFm57Zwn+t6QIn/3fmXh67nrMXlU3973LYssEFyCPhDfpdvfM2uW1u0pQ8NTxxHXDerTE0B4tMOGj5dgyeSTW7S5Bg9RkNG2Yite/24K8ZpnISE3Ggg2B9OgXPvstRvc5AQDwwpcb66Rf+deXG7EmzAX95+0H8fP2gwACOalyMtNQdKAMM1bswrw1e9C0QSp+3bcNGqQlY39pBfI7ZKO0vAr3f7oKyY1btFX7fwzPERDRGAAFzHyDsn4VgIHMfGvIMSuUY3Yo6xsBDERAOfzAzG8q218B8DkzfxDtnrkde3LmZY+p7rtyYHvkZKbhmXkbMKBjDpo0SEWbpg3QPqchpny9ybLJ2yHdcnUXgREEQbCTopduPly5f3uT8O2e8RoionEAxgFAcuNcZEY47s2F22qXf9xcbINkAUQJCILgerhG1V5lxmRxEYB2IettlW2qxxBRCoAmAPZrPBcAwMxTmDmfmfMzsuqHmQfp1jIL/Toc39+0QSo65WaiQKm5ahVm+TILgiBYiOq0lBkjgkUAuhJRRwQa8csBXBF2zDQAYxGw/Y8BMI+ZmYimAXibiJ4AcAKArgB+RAy6t8rCXy8+CYVbinHVoDxc8u/vUVFVg2X3n19bo3Z7cRmaNkxFVkbd3CF7S8rBzMhMT8Huw8ewvOgQRvdpU5t9cM2DBQCAcf9ZjM65mejaIgt3fbwcHZo1xNb9AQ+g24Z1Rc/WjfGvLzfi4V+fhJ2HjuLc7i1w4t2fo7Ka8faNA3HrWz9rLskoCIJ3GT+8OyZ/vqZ2/aQ2TTCgYw5e+XYzFt41FOt3l6BhegraZjfArW//jNM65aCiirF212HMVywJw3q0wJzV9XNytctugO0HItflGN67FU5u2xTbikvx07aDWKvMJ/RonYWiA0eRkZqMgt6tsOPAUSXnl3oRWVPiCIhoBICnACQDeJWZHyaiSQAKmXkaEWUA+A+AvgCKAVweMrk8EcB1AKoA3MbMn8e6X3gcweZ9pchMT0aLrPhT927bX4ZmjdKQGeaju/9IOfo9NAcrHwh4ti7cvB8nt22K5hEyY9bUMJKSCKXlVThWWY1+D7mzMpkgCPGRlpKE78efW/tub5k8El+v24OU5CSc2DILqclJaJyRgq37y5DXPJIRG1hRdAgXPvstXrumP37adgDPztuA35zaBoO75uI2JWPvf64fgD/+92ccKDvuJdU4IwW3DOmMv89cWyee4UBpBf799Uac17NVHasIEHCSWV50CKe0y1aNI5CAMovxe55zQbCCJg1SbXchHd67FT5fESjZuWXySGzaewRLdxzEr/uqOuLEpKq6BvdNW4l7LuyJ3YeP4bZ3l+CZy/uiXU5DMDM6TpiBz/7vTLRukoF+D83B45ecgtvfX4o5fzkbHZtn4lhldb2OaywiBZSJIrCYbnd/XicPvyAIxmnZON3WoLJFE4chMz0ZPe+dBcCeyOKftx1An3ZNQUQoOngUJzTJwNrdJejeqnHc17Q0sliITPdWWU6LIAgJw4Th3QEAb1w30NL7jDurU+3y2zcORG5WOhqmpeCOgm6YdusZlt47SN/22bUp2ts0bQAiMqQEouEZ91Gv0ja7IZbukIyasXjrhoHo1yEbz83bgOfmu6+Yj92c3LYJlslzU4eczDTcdHZn3HR259ptE0f0wJNz1qGswtz62neN6IHTOuWgXXbD2oRzAPD7IV1MvY9bkBGB4AjBkdJ4pYd3RpfmyEhNRo/W1vR4vEbDNH/30aZeN6D22QhSHRaaP/f2s3HtGXmmKYHGGYHvPGj2Obd7yzpKIJERRWA1kkxGlQt6tcKLV56KGwd3wuZHR9RuH3FSK6yaZG/uwbxmDW29Xywu6dcW+R0ix8r4gf552bj57M64elCH2m3h85mdcxvVSyFjhJPaNrHF9u9GRBFYTEsDLq1W8fq1/Z0WAZv3laKgd2skJ1GdUpVEhIZpKdgyeSQuy28X5QrmcXOIqcENPHbJKWjq09q5QVKSAk3TpNG9sWD8ubhxcEc8MLqX6rHpKeY0Y5MvPtmU63gRUQQWc0dBN6dFqEdflQLgdvPpstjlIRumWxetPfvPZ2FAxxwAzlVFa9Ig0NgX9GqFMf3quiCm+DxVbXLI/9+maQNMHNkzopvmSW3qpc6Ji3Y57hoZ2okoAovJcGHqicw0+2T6YcJQFN49LK5zj5o8ARhkzYMF6NoyCy9dnY+v/3YO1u0uiX2SBXz2f2cCAG48q2MdEwgAX08Uz7xtcB1FEItWTdw36vYaogh8RM/WjTH1ugG1dtW/XRB9tHLrOcY9JFo1yUDzRunY+MiI2vQdAHDlwA5RzgrwzqLthu+vRlA5N2mQivbNGjqirBtnpCBJaez2H6nAyW2b4sbBnXDlaYHvZfUuZ5STG9DrIvnQRb0N39OvcwNBRBHYQE+XeMI8evFJOPvEXACBRv6KAe2jHh8tPF4vyUmEJGUu4MNbBuG+X4XXLnKOOwu6xz7IZGoYSFa+j1TFxj1xZI/aRs3fhiF9NG2YJkWfDCKKwAZm/Gmw0yIACPimB/nrBd2QnZkW9fg0kybhggSH+7mNMjR5e1hhJ3/p6npBlbWJCu2kqqam1qTRPLN+3qrg/IUbCc5tuAlSz6UmaEQUgU189bchToug+rJ88oe6UZKvXdsfH94yCACQlhz/y6VWYDvYridpfOomjOgR9/0jofYf1TiQZiXoE7/mwQKc1Lb+ZGf4nIHThKZxtzLHT7wmmvAYAz24wYvOaUQR2ER4Omy3cEq7pnXW++floE+7gFdRitYWWwVWKcIYVERae29m9fEmX3xS7XJFdf28T+0d8BYJfreR5ieMfPeCdkae3BpDurVwWgzHkafNJrxiw8xISaqVVY/nRigf3DwINVHy7GVlaIuaNeM769chG5f1jx6PkKThRlk6szzG4sWr+kXdn5Hqrlczp1F0M6JX6eBjl9FQ3PW0JTBaGhsrGKjT1pySnBTSc4/vngfKKiOaW7ZMHmmrTT40YK1Jg9TayfJwbhvWNep17jV5cjuSHEEau8QOv/GRQNR344xUyz1rnEjQaMCilFCIIrCJZIcms/KaZWJMv7b49NYzox6n1vtPilPm1GQyxe5uxjsa/LfO6NIML12dHzF/+zWn50W9jtkT57FwS/xJ8Ln4cfN+AMCHt5yOId2iK7F4+dwBpwovpuG3An9ntvIBD4zupalR6Z+XjXfGDaqzLTdLvQpbLJiNTd6FXscoQWX21g2nabpXVkYKSo5VWSKL12ib3aB2+adtBwEETG1GOzV/PLcLnplXP8OsE54/PvxZVZERgU3orSRkFlqUwMK7hmJKmFvluoeGx50JtE12Awzp1gJnxTB/2IHWtiU7Mw1f/nUI+uepm9LUJr/18PHvT69dXvtQQZQjnSeYFuXpy/vWbrvnQvNMY38+78R6235zanxVvoxSI7YhAAYVARHlENFsIlqv/FVNYkNEY5Vj1hPR2JDtXxLRWiJaonwSevq+a4tGtt6vo8aAsJaNM+rZ7Y2YQk5smYVXr+mPN64bEPc1APt7a3nNMyN6KhkdEYTmd4rX5KYFM8wrGSmBzkOLkBFhaOyA0d9FrefvVGxCy8aSngIwPiIYD2AuM3cFMFdZrwMR5QC4D8BAAAMA3BemMH7HzH2Uzx6D8riaKwZGj+Q1m9Ym5mAZ2t1+HW2G/TbVxDTFZmGlIjDSoAbTccea37HCrt7zBGei7687s6Mj93UbRt+S0QCmKstTAVykcswFAGYzczEzHwAwG4C7x8YWYbeducrEYa8Tc92X9DOehvrB0cbz0ADGfrvwdNrxuuVqwYiSOTWkBsIrY/PRpmkD1ePOtaBTEJ591S6s/C28hFFF0JKZdyrLuwC0VDmmDYDQ7GE7lG1BXlPMQvdQgseJN7PZF7tdtrd9pJuYkJO/ReP4JrzNpKD38ajc+y3OsaS3XQt1CHjk14HAO2ZgaI+WEV2erxqUhykhcRBf/nUI/vcHe+r4GsXseJBEIaYiIKI5RLRC5TM69DgOjBf19pt+x8wnARisfK6KIsc4IiokosK9e/fqvI07GHXKCbbda1iPlnjkYnN6w4D2OYNmMfIX2U16ij43zBtDipaHYtbYysyKWmro7UsF8zllpafUOhaET4w/NubkOsoMqDvyaJCW7Jn6CTI1rE7Mp5KZhzFzb5XPJwB2E1FrAFD+qtn4iwCEjo3bKtvAzMG/JQDeRmAOIZIcU5g5n5nzc3Od90aJBzsHPFkZKbobQTWG9QiYAbTK/s9LTzF8Tyc5rVOzetsu6nMCBndtHvc1gw3ruLM64byeaoNm89Br6khR8klVR7F9XZLfrl7uqND7EOm/75+GHg/gs9MsdP8o9Spnfsdo92QagKAX0FgAn6gcMwvA+USUrUwSnw9gFhGlEFFzACCiVAAXAlhhUB5BwaxEai+P7R+IKNV4ucqqKLklPMqdw7sb8i4J/hR3jeih6zrrHx6u+156O+bBmIC3b4weZxFOaL8gLTlJd2nNUBfSf15iX+fBqbkIt2PUYDYZwHtEdD2ArQAuBQAiygdwMzPfwMzFRPQggEXKOZOUbZkIKIRUAMkA5gB4yaA8gkVoVSztXVYI3gyyGxozd8Ufoa2/n9ZAZ/W5oKmqT0jyQS0/dWhSvKYN09BE4/Oh1aVZsBdDioCZ9wMYqrK9EMANIeuvAng17JhSANEzbwlx40QkbGoy6a4uZSXBUpBGmKQxMjvIWzcMxO9eXlhnm13pKf73hzNAOnO2qkUJa/E2C0+OqtV0GBpbkpOZhuLSCk3nCdbiPidrwRTMzrGvJRiusto9U3ErHrgAvU0oaq63V35Gl/pzCXrNJvFC0O/mq3b8f3/cFvM8M2IhFt89DD/eVa8fGRdPX97HlOv4FVEENhNvIXenuW1Y/bQAbkatME48qE3uXnN6Hp67oi/SoiiJX/dtg6X3no8fJw5FrxOMKyQtEOmv4aDWoB/WUHgmXjUQ2kEhIrQwKbI3r5mYnIwgisBmmjeyx6/d7L65U2m0nUbt97ptWFdcePIJUXvfAzvmoEnDVLTIsi+FAYF0e6bFG1AVrwfcsUprnAn0jlAuzZdJ41BEESQodqfXzbbJ/OEGgtXmpv9RPa/P69f2x0V926ju08vP95yn+dh43Djj1e/xWoZObGlNvi3dJjHT6t8lBhJmJ5jCtFvPRKVKGUgv8tq1/XHta4vqbR87qANG9TmhtrHtEmHexMzSh9k6AvSCjWFuVjr2lpRrOsfukZ5VsTRaLts/L5BC47ExJ6umGvczMiJIUOz2GmqX0xCdcu3NrmoVfdo2Vd3+wOje6Nehbprqn3T02K0m2MvV09RmqAQdanl03GYp1NLDD+auuiS/neQYCkNGBAmKGYVh/Ep2Zhomje6FI+Wxe405mWlo3igd+46UW17KMRbBXrGeTvcNgzvizuHddd/LbYV6wt1Z1ZDGPzIyInCAm85Wz2djJmZmHg0Sq9ylW1hsgmfW1YPy8PshXUyQxj6CE6anRBjRqNEgNblOMJnW893kKgzot/kndnpL/YgiSFCssNef1LZJPW+LrIwUZKS66zFqZpNn1nHc0SgGG7cXfneqoeu0y4kdHR4ccd41Qv9owgq0dPZDG3/RA3UR05AT2NBulFuU8+cfY07BwI7NcPv7SwE4U3BcUCfYuBnJcDrztsER6xCEEqxnPO6sznHfy0wSPIO95YgicIAcG1I1V1iY/C3UFt3W4zUPEonwxjArPQUlGuY5QtGaIiSveabjcyKhaNEDoccM6twMp7Zvapk8XsNdY3qfcL0N5fGszA8ffKGeuqyPZffwEu/eNMiUvEZGCW8MNXWSE6QjrTegrEuLLHz0e28U07EDUQQOYFZxkk8iVIV6ZWw+/nWl9fn8RvcxJ2jKTF6/tr/t9+yc28iUvEZGSQ1znfGTl4w2neef70Mvogg8TPMs9UnRri2y6pQgNJsuuVmWXdso3Vq5VzarCRaZAYABHXO0pc92xzy3YcxIgudnRBF4mEiPfmiDYAUntW3iKvuwn/nNqce9uEJjR967aVBc9Qys5M3rB1p2bdEDxnDXkyLoItLQ3yv1YwVtvHjlqXX+hpKvpE0AUG8UGF572Gn0Fs3Rg5/MYFYgisDDROoF+fmlaKCjiIxXKOjdWlmK/rvqKaBjRe98aPfoOZaqLMxFpaXzI6OGyIgicAmTRusvqk0g3H5e/ToBmSbl4vcas247C00NlpX0Gh00BH+p0aSB+dliX7km+kS9FdHutUgjbwhDioCIcohoNhGtV/5mRzhuJhEdJKLPwrZ3JKKFRLSBiN4lIn+9xSEM61G/AEosiIB+efW/cj09w0QhOYl8N1G8ZfLIiBlQo3HN6XnHzzO5Af3tgHYR91npWSWTxcYwOiIYD2AuM3cFMFdZV+MxAFepbP87gCeZuQuAAwCuNyiPZ4mnr0QQlzi/MUOpgbBQKfGo57kZ0i0XAHD/qF6W2ev/rDJCDWLFKCSIXQWfEhWjimA0gKnK8lQAF6kdxMxzAZSEbqNAGOS5AD6Idb4fsLuQTKLhB3VIBPQ8oXHtMqAvC6hatHlLk0pF1uLgYzzqlBOcu7nHMaoIWjLzTmV5FwA99o1mAA4yczAGfgcA90Uo2UQ8ekDyqxzHUvuzSwg1+QVHgjU6HpxrTs+rs75l8kh0TpAaEkBgpCPER0xFQERziGiFymd06HEc6NJa9jYS0TgiKiSiwr1791p1G9toZkK+IYKMJEb3CfQCO+UmdvHytJQkdA2ZD6gdEei4xvm9WmHVpAvMFcxFxOoWSccpMjEVATMPY+beKp9PAOwmotYAoPzdo+Pe+wE0JaKgi0tbAEVR5JjCzPnMnJ+bm6vjNu5kcVhlq3ja8yQi1cliP9KpeWIrgnUPDccJIVlBGymeYXo7Ag3TrPUoc7Jb0ihD/X9b82CBzZJ4D6OmoWkAxirLYwF8ovVEZQQxH8CYeM5PNKIN8SM+yASkq5QaFBKbLZNH1pqJ0lwWPezkADVSJLUfvej0YvQpmgzgPCJaD2CYsg4iyieil4MHEdE3AN4HMJSIdhBRcHx6J4C/ENEGBOYMXjEoj2eJ9v5EepAb+TReIJTgYN+vFrIWUSZ7/fqdCPox1JIw834AQ1W2FwK4IWRdtXoJM28CMMCIDIlCPLb+YARxw7RklFVUmy2SJ2AAN53VCWd2be60KI6Rk5mG4tIKp8UQPIy7xpU+ZcvkkYZsq3eP7GmaLG7j4lNjO5JNGNEDg7t6f97IbC7NjxzcZRVuy28UikwVR0YUgUvQOyB447rjA6krBrY3WRr30LSBb4PNNXOsUn00eONZnWyWJDK/H+KOkpaCOqIIXMIJTTMwWId5o1+Hut5C5/XUn6LC63TOzcSZXfxrEgry0e9Px/Q/Ol8hDYjcobmjwB1F7gV1RBG4hIZpKfiPjoyQ4YnlXro632yRXM/c24fgEgfMH26je6vG6HWC8xXSAOtrYcTLyJNbY0DHHKfFcC3idiK4GokB8hYtskxOWWGQ684I1Ad//or6tRyE48iIwEHUcudf0Mt/Jh4hsdj4yAjcOLij02IAAE7t0NRpETyBKAIHUZvkvfK0DoaumWg96AT7d3xBchJJOgePIaYhB7mjoBtuOttczw4pUym4gbxmx1N+bHpkhIOSCFqQEYGDpKckm25TTbQCHQn27/iG3w5oVxv5nuRg50TqdWhDFIHLMJoWINHqFYuJwZsQES7vb79Hl9kjbL8gisAjFPRqpek4K6tACYIefFAiImGQOQKXEakDrMU/+9s7z4mYgdGryHjAu+gpmmMWYgqKj8RqNRKINiG55wFtk8BtsxuaX3rQaeS99ixOFE1yc64jNyOKwCMkJ/nzp5Ienndxg2nIiVGJF/Fn6+JBEszioxmZKzaGnvxVZlPtAtPQnNW7bZfBi8gcgUdING8gwV6euPQU2+95y9md0bddU9vvG0ppuT/rdOjFp/1M7zHqlNh5+RMRUX/mcFEf+5+fdjkNbU8K2Kpxep11MQ1pQxSBy+ic20h1+6DOzdDH4d6VE0T0opIRkiaCzgNOBnXZydWD8uqsd22h/j4JdTGkCIgoh4hmE9F65W92hONmEtFBIvosbPvrRLSZiJYonz5G5EkETgjzFgqlfU5DAECrRPMMikKkyeIF48+1WRJv8tBFvbFo4jCnxbCNcIU34qTWDkniLYyOCMYDmMvMXQHMVdbVeAzAVRH2/Y2Z+yifJQblSWiCveMJI/xT5CPSiKBFVrr6DqEOGanJyPXxd5VoKVeswqgiGA1gqrI8FcBFagcx81wAJQbv5TuW338+1j88vHY9RXEh7RCS0Mtv3HpOFwCSekLQhjwm2jCqCFoy805leReAeJLpP0xEy4joSSLyb9clhDsKuuGOgm7IykitEyl8z4U9AADJPnq6w//Tm6X2raADGRFoI6b7KBHNAaCW6GZi6AozMxHpnaKfgIACSQMwBcCdACZFkGMcgHEA0L594hZrB4DfD+miur1pw0Ahd1+5koa9yE5EqwrepWlDyb2lhZiKgJkjzjQR0W4ias3MO4moNYA9em4eMpooJ6LXAPw1yrFTEFAWyM/P93VrkJ3pn4fbRypPsIBozhfCcYyahqYBGKssjwXwiZ6TFeUBChh8LwKwwqA8Cc+mR0agdRP/PNwtwvzCfd0DEDQx67aznBbBcxhVBJMBnEdE6wEMU9ZBRPlE9HLwICL6BsD7AIYS0Q4iukDZ9RYRLQewHEBzAA8ZlCfh8Ys/eJDf9m+PxXcfH5SKZUiIhVotcCE6hlJMMPN+AENVthcCuCFkfXCE88UZXIhKUhKhWaOQUYEoAiEGMj+sH4ksFjxFeqo8skJ0/DZqNgN5qwRPkZGajC2TRzothuBiRA/oRxSBIAgJhZ/ibMxCFIEgCAlFY6nbrRtRBIIgJBQZqcno3irLaTE8hSgCQRAEnyOKQPAMF/U5wWkRBI8gSQn1IYpA8AwN06WyqqANUQP6EEUgCELCIQMCfYgiEDyDvNuCVkQR6EMUgeAZ5OUWtBKpxKmgjhhdBUFIOP5wThes/OWQ02J4BlEEgiAkHAW9W6Ggt1o9LUENMQ0JnqBzbibO7JLrtBiCkJDIiEDwBHNvH+K0CIKQsMiIQBAEweeIIhAEQfA5oggEQRB8jiFFQEQ5RDSbiNYrf7NVjulDRN8T0UoiWkZEl4Xs60hEC4loAxG9S0RpRuQRBEEQ9GN0RDAewFxm7gpgrrIeThmAq5m5F4ACAE8RUVNl398BPMnMXQAcAHC9QXkEQRAEnRhVBKMBTFWWpwK4KPwAZl7HzOuV5V8A7AGQS4H0gOcC+CDa+YIgCIK1GFUELZl5p7K8C0DLaAcT0QAAaQA2AmgG4CAzVym7dwBoY1AeQRAEQScx4wiIaA4AtRC9iaErzMxExFGu0xrAfwCMZeYavfnCiWgcgHEA0L59e13nCoIgCJGJqQiYeVikfUS0m4haM/NOpaHfE+G4xgCmA5jIzD8om/cDaEpEKcqooC2AoihyTAEwRbleCRGtjSW7gzQHsM9pIWLgdhndLh8gMpqB2+UD3C+jHvk6qG00Glk8DcBYAJOVv5+EH6B4An0M4A1mDs4HBEcQ8wGMAfBOpPMjsJaZ8w3KbhlEVOhm+QD3y+h2+QCR0QzcLh/gfhnNkM/oHMFkAOcR0XoAw5R1EFE+Eb2sHHMpgLMAXENES5RPH2XfnQD+QkQbEJgzeMWgPIIgCIJODI0ImHk/gKEq2wsB3KAsvwngzQjnbwIwwIgMgiAIgjG8Glk8xWkBYuB2+QD3y+h2+QCR0QzcLh/gfhkNy0fMER19BEEQBB/g1RGBIAiCYBKeUgREVEBEa5XcRGrpLFwlDxFdQ0R7QybJb3BCzjCZXiWiPUS0wmlZgNjyENEQIjoU8h3ea7eMKjK1I6L5RLRKyaH1JzfL4tLvMIOIfiSipYrcD7hZFje+y0GIKJmIfiaiz+K+CDN74gMgGYGI5E4IRCcvBdDTzfIAuAbAc05/d2EynQXgVAArnJZFizwAhgD4zGk5w2RqDeBUZTkLwDqnnkUtsrj0OyQAjZTlVAALAZzmVlnc+C6HyPYXAG8b+Y29NCIYAGADM29i5goEYg9Gizz6YOavARQ7LUcQt8mjBWbeycw/KcslAFbDofQobpJFDxzgiLKaqnwcmbB0kyx6IaK2AEYCeDnWsdHwkiJoA2B7yLrTuYm0yvMbJf32B0TUzh7REo5ByrD9cyLq5bQwoRBRHoC+CPQiHSWGLK77DhWTxhIEMhLMZmbHvkONsrjxXX4KwB0AaoxcxEuKwIt8CiCPmU8GMBvHM7UK2vkJQAdmPgXAswD+56w4xyGiRgA+BHAbMx92sSyu/A6ZuZqZ+yCQXmYAEfV2sSyue5eJ6EIAe5h5sdFreUkRFAEI1cJRcxPZQEx5mHk/M5crqy8D6GeTbAkDMx8ODtuZeQaAVCJq7rBYIKJUBBret5j5IzfL4tbvMAgzHwQwH4F6JY4SSRaXvstnABhFRFsQME2fS0Sqwbux8JIiWASgKwWqmqUBuByBXEeulUdJxBdkFAL2W0EHRNSKlFS1FEhjnoRAwkInZSIE0qGsZuYn3C6LS7/DXFIKVBFRAwDnAVjjVlnc+C4z8wRmbsvMeQi0P/OY+cp4rmU06ZxtMHMVEd0KYBYCHjuvMvNKt8lDRJMAFDLzNAB/JKJRAKoQmBC9xil5gxDRfxHwImlORDsA3MfMjuV4UpMHgck6MPOLCCQlvIWIqgAcBXA5K64SDnIGgKsALFfsygBwl9LbdoUsANoDrv4OWwOYSkTJCCim95g5fvdHC2Rx+7tsJhJZLAiC4HO8ZBoSBEEQLEAUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPkcUgSBEgYiahWSc3EVERcryESJ6wWn5BMEMxH1UEDRCRPcDOMLM/3RaFkEwExkRCEIcKDn+P1OW7yeiqUT0DRFtJaKLiegfRLSciGYqKSBARP2I6CsiWkxEs8KiVQXBMUQRCII5dAZwLgLpB94EMJ+ZT0IgknekogyeBTCGmfsBeBXAw04JKwiheCbFhCC4nM+ZuZKIliOQcmSmsn05gDwA3QD0BjBbSfuTDGCnA3IKQj1EEQiCOZQDADPXEFFlSC6fGgTeMwKwkpkHOSWgIERCTEOCYA9rAeQS0SAgkDraLQViBEEUgSDYgFLOdAyAvxPRUgBLAJzuqFCCoCDuo4IgCD5HRgSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPgcUQSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPic/wcvziJ0eY2VRAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "librosa.display.waveplot(samples_out, sr=sr)\n", - "print(len(samples_out))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "infectious-welcome", - "metadata": {}, - "outputs": [], - "source": [ - "import librosa\n", - "x, sr = librosa.load(wav, sr=16000)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "musical-anatomy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float32\n", - "float64\n" - ] - } - ], - "source": [ - "print(x.dtype)\n", - "print(samples.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "lucky-paraguay", - "metadata": {}, - "outputs": [], - "source": [ - "sf.read?" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "annual-christmas", - "metadata": {}, - "outputs": [], - "source": [ - "librosa.load?" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "infectious-seeker", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 47, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.allclose(x, samples)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "pregnant-conditioning", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "id": "logical-happiness", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.uniform?" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "rocky-plastic", - "metadata": {}, - "outputs": [], - "source": [ - "random.uniform?" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "focused-compensation", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.RandomState?" - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "centered-repository", - "metadata": {}, - "outputs": [], - "source": [ - "random.sample?" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "id": "inner-invite", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array(['3', '5'], dtype=' 1.0, speed up the audio;\n", - " speed_rate = 1.0, unchanged;\n", - " speed_rate < 1.0, slow down the audio;\n", - " speed_rate <= 0.0, not allowed, raise ValueError.\n", - " :type speed_rate: float\n", - " :raises ValueError: If speed_rate <= 0.0.\n", - " \"\"\"\n", - " if speed_rate <= 0:\n", - " raise ValueError(\"speed_rate should be greater than zero.\")\n", - " old_length = samples.shape[0]\n", - " new_length = int(old_length / speed_rate)\n", - " old_indices = np.arange(old_length)\n", - " new_indices = np.linspace(start=0, stop=old_length, num=new_length)\n", - " samples = np.interp(new_indices, old_indices, samples)\n", - " return samples" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "tracked-purse", - "metadata": {}, - "outputs": [], - "source": [ - "samples, sr = sf.read(wav)\n", - "samples_out = change_speed(samples, 1.0)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "steady-mileage", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ipd.Audio(samples, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "regulated-google", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ipd.Audio(samples_out, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "homeless-forge", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples_out = change_speed(samples, 1.1)\n", - "ipd.Audio(samples_out, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "exciting-blocking", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples_out = change_speed(samples, 0.9)\n", - "ipd.Audio(samples_out, rate=sr) # load a NumPy array" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "through-botswana", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "66078\n" - ] - } - ], - "source": [ - "print(len(samples_out))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "cellular-violence", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting matplotlib\n", - " Downloading matplotlib-3.4.1-cp37-cp37m-manylinux1_x86_64.whl (10.3 MB)\n", - "\u001b[K |████████████████████████████████| 10.3 MB 691 kB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (8.1.0)\n", - "Requirement already satisfied: numpy>=1.16 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (1.20.1)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (2.8.1)\n", - "Collecting kiwisolver>=1.0.1\n", - " Downloading kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1 MB)\n", - "\u001b[K |████████████████████████████████| 1.1 MB 45.9 MB/s eta 0:00:01\n", - "\u001b[?25hRequirement already satisfied: pyparsing>=2.2.1 in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from matplotlib) (2.4.7)\n", - "Collecting cycler>=0.10\n", - " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", - "Requirement already satisfied: six in /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", - "Installing collected packages: kiwisolver, cycler, matplotlib\n", - "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" - ] - } - ], - "source": [ - "!pip install matplotlib\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import librosa.display" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "undefined-parade", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8K0lEQVR4nO2dd3hUZfbHvycdQoAEQpEWmlQVJICoKApqABdcF8u6KlbUXX+77rqrIFZs7Lr2si5WXHXtrigI0myoSFB67xDpoYQEUs/vj7kTJpM7M/fO7XPP53nmye33ZObe97zveU8hZoYgCILgX5KcFkAQBEFwFlEEgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPscURUBEBUS0log2ENF4lf1/IaJVRLSMiOYSUYeQfWOJaL3yGWuGPIIgCIJ2yGgcARElA1gH4DwAOwAsAvBbZl4Vcsw5ABYycxkR3QJgCDNfRkQ5AAoB5ANgAIsB9GPmA4aEEgRBEDRjxohgAIANzLyJmSsAvANgdOgBzDyfmcuU1R8AtFWWLwAwm5mLlcZ/NoACE2QSBEEQNJJiwjXaANgesr4DwMAox18P4PMo57aJdcPmzZtzXl6ePikFQRB8zuLFi/cxc274djMUgWaI6EoEzEBnx3HuOADjAKB9+/YoLCw0WTpBEITEhoi2qm03wzRUBKBdyHpbZVu4AMMATAQwipnL9ZwLAMw8hZnzmTk/N7eeQhMEQRDixAxFsAhAVyLqSERpAC4HMC30ACLqC+DfCCiBPSG7ZgE4n4iyiSgbwPnKNkEQBMEmDJuGmLmKiG5FoAFPBvAqM68kokkACpl5GoDHADQC8D4RAcA2Zh7FzMVE9CACygQAJjFzsVGZBEEQBO0Ydh91gvz8fJY5AkEQBH0Q0WJmzg/fLpHFgiAIPkcUgSAIgs8RRSAIguBzRBEIgiD4HFEEgif42/tLsXirpKASBCuwNbJYEOLl/cU7kJ6ahH4dsp0WRRASDhkRCIKQcKwoOoSZK3Y5LYZnEEUgCELCMeGj5bj5zcVOi+EZRBEInsGDsY+CQzDkYdGDKALBM8irLWhFOg36EEUgCELCIYpAH6IIBEFIOEQP6EMUgeAZ3l64zWkRBI/gxWSaTiKKQBAEweeIIhAEIaGorK7Bml0lTovhKUQRCIKQUBworXBaBM8hikAQhISiRqYHdCOKQPAUFVU1OPHuz50WQ3AxNTJRrBtTFAERFRDRWiLaQETjVfafRUQ/EVEVEY0J21dNREuUz7TwcwUhlKMV1aioqnFaDMHFiCLQj+Hso0SUDOB5AOcB2AFgERFNY+ZVIYdtA3ANgL+qXOIoM/cxKofgE8hpAQS3UyP9BN2YMSIYAGADM29i5goA7wAYHXoAM29h5mUA5CcSdDFj+U7kjZ/utBiCh5ARgX7MUARtAGwPWd+hbNNKBhEVEtEPRHSRCfIkNHnjp2P/kXKnxbCNtWFugEkyIhBiMPKZb5wWwXO4YbK4AzPnA7gCwFNE1FntICIapyiMwr1799orocvYU+IfRRDetyMSTSBEp7SiunZ5n486TUYwQxEUAWgXst5W2aYJZi5S/m4C8CWAvhGOm8LM+cycn5ubG7+0HmD+mj1Ysv1gve3HKgMPeLWf/ONkmC8YYM9hUQRaMEMRLALQlYg6ElEagMsBaPL+IaJsIkpXlpsDOAPAquhnJT7Xvr4If3jrp3rb/zFzLQB/KYLw/3Tu6t2OyCF4E5kv0IZhRcDMVQBuBTALwGoA7zHzSiKaRESjAICI+hPRDgCXAPg3Ea1UTu8BoJCIlgKYD2BymLeRr8kbP73OROn+0kDv5kCZfyMn//TOEgCSVEzQhjwm2jCleD0zzwAwI2zbvSHLixAwGYWf9x2Ak8yQIdFQM4UHH+oHPl2FId1a2CuQQ0R6kcurapCRmmyvMB6kpoZRw4yUZDdMB9qPjAi04c+nw8UcVHr7as9vsZJDZfO+UjtFcpRIJQeHPfGVzZJ4k7/PWoPu98x0WgzH+GHTfqdF8ASiCFzG8qJDqts37DmCbzfss1ka54nUodtx4Ki9gniUVb8cRpWP5pQ+WVLXT0UUgTZEEXiEV77d7LQIjuCfJsxanGgQDx+rxLIdB2295+qddeNOkpOkidOCfEseoVri5gUDXD7lB9vv+cQX6zDquQW23zcUn06N6Ea+Jo/gp+F9KDLXZ4yFm4odu3eVCzovBb1bOS2CJxBF4BFq/KoIxDhkiIpq5xrjJAeiwMOfl7Rk8SzTgigCl1J0sO5kaLWG9rDkWCXKKqosksghRA94FicUgRAfoghcRiRTiJYRwbAnvnLEFmwloge8i+gB72BKQJlgPdOX74x5zO7D5ThyLMFGBIJnkRGBd5ARgcsw+u5UJ9jsqqSS8C4rIsTEWMl7i7bHPkiohygCB2FmVJk8mZdoCelED3iTWSt3YeFm+z2WDpRV1lkXZwNtiCJwkBe+3IguE80txJ5oikDwJou3HqhdPmPyPAclEbQgisBBPltW3+4f+gLFQ6LpgQT7d3xDqIUz3ANOcB8yWewgq3cerrftqTnrHZBEEMzjdy//gAUb3JHjp0hyUmlCRgSCq5E5Am9xoLTCNUoAAB79fA0A4L8/bkNpuXjURUIUgUuoqq7Bawu0J5Yrr6qus/7yN5vMFsn1/PX9pfhmvb/rVwPA7sPHsOvQMafFAAAcrayOfZADTPhoOeau2eO0GK5FFIFL2LK/FA98qr0427pdR+qsPzR9tdkiuZ4PFu/AB4t3OC2G4/z6+QU4+7H5TosRFSkx6m5EEbgE0hlA8Kvnvq1dTuSc6+L+F5tjVTUor6rvhjxzxS7bZYn0GF8/tdBeQQRdmKIIiKiAiNYS0QYiGq+y/ywi+omIqohoTNi+sUS0XvmMNUMerzH6uW9hJI7s2tcWmSaL23htwZaYx0xfthMb9hyJeZzfePyLtbbfkww9yYJTGFYERJQM4HkAwwH0BPBbIuoZdtg2ANcAeDvs3BwA9wEYCGAAgPuIKNuoTF5j6Y5DukcEwPGo20oHM0y6gT+8/RMeneE/01iQSHmoJMODoBUzRgQDAGxg5k3MXAHgHQCjQw9g5i3MvAxAeIt1AYDZzFzMzAcAzAZQYIJMniPaOxvpRT+s5BXya60C4Pj35tdGr7i0AgePVsY+UBCiYIYiaAMgNMHHDmWb1ecmFNESdHW6a4b6DpZcPMf/e/9ogsv+/T0qlDmBaC6RTjwaTipkv9bsMAPPTBYT0TgiKiSiwr17ve8yOPbVH+usx/MCMRiLthiLRE4U9h4pd1oES7lhaiEOllUAABZuLsbhY4FRQFKSuxSgk9IcijAyyhs/3WZJvIcZiqAIQLuQ9bbKNlPPZeYpzJzPzPm5ublxCeomvlpnXJkxu6McoJN8suQXAMDS7QedFcRi5qzejeUh2TyDvX09De/KXw4ldKMo44H4MUMRLALQlYg6ElEagMsBTNN47iwA5xNRtjJJfL6yzXfENyIQL40gLusYW8LRiuPBWkG3Wj05/x/6rO6E+okTP8cvCZQH6JVvowdV+t2MGg3DioCZqwDcikADvhrAe8y8kogmEdEoACCi/kS0A8AlAP5NRCuVc4sBPIiAMlkEYJKyzXfE6zXk10nScOL5/rwGA9hbUn58BdE7EOH7UlPqvu4V1TXYsr/UPAEd5sWv/BddbxamzBEw8wxmPpGZOzPzw8q2e5l5mrK8iJnbMnMmMzdj5l4h577KzF2Uz2tmyONF4mnGGJKLx2/0f3gOAODCZwMBhXqem68Vc+ScVbstm1j9YlXkCGIr3ZwPllVICnYDeGayONFRy0QaC2ZgW3H9Hp0f4wqqaxj7E3zCOJQ9JeXIGz8dO+PIMXTDG4VY8Ysy32By23n3/1ZE3Ldwk3WD/QofPvNmIorAJcQTgs9g3Pnh8nrbS3xat7jfQ3NwJIEzTKqN/lbF0YEArIk9eXZu9BTqKckWmu9kMGAIUQQeJpJZyM+eREcSUAlGGy2GNq16TCMXv/CdAYnUeXz2uqj7Uy1UBFpqdYsZNTKiCDxMpBffz7bSRExSN/zpbwAAN7+5uN6+X0JMQ+EeQJo8ymycY09Jsq65qapOvN/dTkQReJhIj77VL8XmfaU49/EvLb2HoI1nQswxySE+tC99vUlTL9lORj+/wLJra/lXE7GTYBaiCDxMJM+PvUfK6/icm83S7Qexaa873Q79Oj8C1B0JPjxjta8mz2tcpvS8higCBzArsGXwP9SLkVz8wnf44zs/m3IPNYI9q2U7Dlp2j3g5/8mvbb/n/iPl2HnI+cCscM8ZTSbCBGk/tfwbEnwZGVEEDvB+ofVVtfYctq50YVCPXfHSQsvu4SUum/IDBj06z2kx6plH/NRJ1jIiCDUN7Tl8DP+cZX+9BrciisABNu61vohKekqyZdcOvnNHyqtw+FhlbQI0vxJMBuc04SPNkjhcafcfKa/NbBqNPSXHcOGz3+i+vlXoVXqzVu7Cc/M3WCOMBxFF4AQ2jFDTUqz5af/15UY8FtKTuuj5BRj5tHsaBD9jxgCg30Nz8Pjs2D3ldbuOYEXRYcx3SUF4LebW0EN8NFjShCgCB7DDVmmFz/aew8fwxOy12BVidtq0txTbDzhvHw+lyvYoU3fYnoMN3UOfrTJ0nc0aHAGCHkrXvu6OMqnSsBtDFIEDvPjVRsvvkZps/k874JG5qPSAv/Zpj841fI1v1+/DnCh5c9xI0Ab+2bKdms+pZq430f39pv0xz7M0SjgONM0RuP/RdQxRBAlKsh/yMkdg3xFjNvtjldW48pWFuOENbWk/gpHcTuf6DwaU6/GXf+P7rXFNdLvt8dLSyIcqC1EKdRFFkKDYnZX5UFklikvdMWlqlEgpHf73cxG2F5fV2ZY3fjoOlrljsjyoAPQElu+LM9bAbQ2plhFBoVLNb8n2g1KbIAxRBIIpjHnxO5zzzy+dFsMUfh0hD89t7y7B8Ke/qW1EIjWi63eXOJLmI9i21dYs0EC8YrqtGdXSrr9bGCiPftHzC7B2t/Wee15CFEGCYnehlvV7jkSsGZtIHCmvqv0/Rz+nnjLhvCe/xmfLfjHlfnrNTXoVULw943g71Fa5TuuVp9rHiRnVEEVgM1p8tM3AZSZcz3Kssn6qji9W7cayHQexN4pZZXtxme3eSzXMuhv2eEcu8SqQZIs6KHrzCM1zidurW0hxWgC/cfZj6mkh3M6bP2x1WgRHmLbkF1zav12dbXd8sCzmef/8Yh3++UUgLfO8289Gp9xGlsgXCrN+k42aHtDiaBCvaSj82oFyq8aVQ2m5vtxaRh0KEg1TRgREVEBEa4loAxGNV9mfTkTvKvsXElGesj2PiI4S0RLl86IZ8riZeCpKxYOeouZaWLBhX8xj3ORSmDd+OraaUI+3vEpfA7NbJbXHLpt+83gaZ7XEhdecnhf7PBMmW//435/RccIMw9cBgN++9IMp1/ErhhUBESUDeB7AcAA9AfyWiHqGHXY9gAPM3AXAkwD+HrJvIzP3UT43G5VHCGD2CFzL9aqqGTsOlMU+0CYueMp4AroXv9qky3wy8JH6MQzlNpmInpu3Xr+tXOWEBqmx05METexNG6YC0G4qenjG6trlaUvNmUcRjGPGiGAAgA3MvImZKwC8A2B02DGjAUxVlj8AMJTsns0UbGFFUXylE63gWKXxBrjo4FHsKTHWo4+3ULxeO/yc1XtQVqEvv1BQyYUXtYl5niLbwbJArimtrsMzV+zSdR/BHsxQBG0AbA9Z36FsUz2GmasAHALQTNnXkYh+JqKviGiwCfIIMG+y+Ok565E3frpmO266RTmOnOSLlbt1N7ChBL+6eWt266oTce7jX+m+l16dEwyGGxXiAaXlpw5VbuWVNVEnztX4Zv3e2uUPF1ufjTfIT9sO2HYvL+H0W7sTQHtm7gvgLwDeJqLGagcS0TgiKiSiwr1796odIoRg1oDryTlKHVqNDcy/vrQ+fYaVbNtf37R137SVmLE8/p5sMLfUda8X4sOftDd6m/fpn+PQa7sPVrNT844KsmT7wXoeUKHmshpm3VXxrnrlx9rl299fqutcI1z1sqROV8MMRVAEINStoq2yTfUYIkoB0ATAfmYuZ+b9AMDMiwFsBHCi2k2YeQoz5zNzfm5urgli28/8tfa5rH38cxEmf77GtOuVa3R7/XFLsWn3NAO9LpwPfLpSdbtZIyyrA830KoKgPITjpqjwpIgXPb8A05fXzV8UepcaZs/UyRaLtDpmKIJFALoSUUciSgNwOYBpYcdMAzBWWR4DYB4zMxHlKpPNIKJOALoC2GSCTK5ES1ZHMzFzGOxESL4ZMRcHXJD+Yc2uktrl+6apKxqz0Psz7VGikEvKq2o7DkT1A79Ckw3OWbUbEz5aXrs+6NF5ltYjNpMjcdRo8AOGFYFi878VwCwAqwG8x8wriWgSEY1SDnsFQDMi2oCACSjoYnoWgGVEtASBSeSbmdldXUoPk2JiZjAn+nvvLtpm+BoPTzeWkjmIkY7k32fWHZlZ2Xs2cu3QrKNDH/8KRREmkN9auDXuHEWRsHO0LNTHlIAyZp4BYEbYtntDlo8BuETlvA8BfGiGDF7g5+0Hbb2fmcnQnJhkqzAh5fUeHXl3AHsUXg0zki2K/TaiCJbtOATguPkkkreTFeaVb9fvwzndWph+3Vh8sHgHxvRra/t93YZEFtvIpzb7Ta/epc2V81hlNVKSCClRahjoUSo7DpRh6ndbUF0D3Pur8JASe9HTZkWbMDXa9oUGmllpZRv13LeGrxEcSB4N+T5CzXRWqLDDDuWpWhMh06zfcNpryDfocRs0C2Zttv3THp2LP79X13PDSGTuut0leOXbzXh1wea4zg9i57TeoaOV6H7PTBRZVG0tNNDsCgujYM2YE3ng04A57faQZ+KRkEAwo7y1sH66khnLtRfTMROZOw4gisAmzAjJj4d/zFqLT5YURU1NfLCsEou3FGNEWO3hrSqulFogIlMK45jxkgZNJU/MXocNeyJnvgyaQdbuLlHdb2Z50cKt3vBlX150qHY5NysdAOpVM4uHiR+vqLet1IGOEmB+KhavIorAJpxSBKt3Hsaf3lmC/g/PiXrcL4eOYVXYMDleiaurzUkkZsYrGjRzPzN3PSZ8tCyiDf0HDeUZ7cSuLLWxCH5fF/RqBSDgITTXosydRmstx4O4kwYQRWATTqU//3KtvuC7UL/7eJVXZnpKxHTDI5/5RrOZzAzV2TgjtXZ50ZYDmLtavQ7xLW/9FPU6T89db4I0x4kVLOaW2g6/ejYw53CssrreiNFsXv7WmCkxHtxWctMpRBHYhFpyLzdSUV1TayapjtNr595PVkQ066z85TCKy7TlpTHjK5uzeje+XX88c2pVnF418UT5RiPWpG60iWs7CY4Sd6mMGBMBtyhcpxFFYBNujbwM9xXfsq+sdlu8jeb6PUdUba/BiWu7g9OufOV4WgG1uYvwOsR2EKuhj/e7F/Tx1sJtEWtU+wlRBDYRy0bvFGdMnldnfcQz39TWHq4wkD5ZLYIz2LZpNZO99p35poLMtPoe004M1oKms7zx01UnsZ3yoonEzJXHcy01zrDO6/z8J/Un2gO0FdOJxHCLTV5eQBSBDdz3SX0vCSdYExJXMH/tnohFV4K9UbMnLINzDlU1NZpGBduLzXflvOGNRfW2GVF48ZKcRLXRudtVajhMX+YuRRDK4WPWpWlYF2dReSdSoCQSoghsYOr37ijzOPW7LVinuEde+9oifLY0emNz1GQ7ddA8du7jX+EFh7KUqtUoqM2waiPJSVSraIN/3164Fa8qE6bSrGmnvKpad/ptoS6iCHzEf3/cjoKnvq5tkGOl/73nf8ZHMmUVVTh0tBJ546fjlAe+qN3+3LwNMc8deXJrw/ePRkVVDUqOVWKPSnlJqzl8rKr2d2jeKA3b9pfhro9XYJLiQtmmaYbtMrmFw8f0TeA+aILbad746Yav4WVEEVhMvNWprKKGgVIDRVb00vPeWfjzu0sA1E1lfUxDLeDcRumWyJQ3fjqKSysw/sNlOOn+L9CqSQNL7hOLexST4dpdR/C/JXUztw/u6s1U62Zw8v1f6HpvzPbo8iOSa8hiXvgyds/XbuzuAc9TCUA6s0tzW2UI59QHZ9cun9G5me15oIDjMR53fby83r5KB+Yt3ERVDSNNmQA+Ul6FBRv2ITcrHae2z6537Jqd6tHgejlUVokmDVNjH5iAiCKwmNW7zHlIzWTYE8aLuhulb/tsrN55GN1bZalGd57ywBe2+Xh/t9FdUcX//XEbSn2eNz9oNnt9wWbcr+Q+atIgFUvvO7/esfs11kuOxZ0fLsOLV/Uz5VpeQ0xDVuMuy5BrWLPzMIY//Q1mrdxdp3e+YMM+5I2fbmugzzQHRgPRmPDRcvy07aDTYjjK1uJSfL1ub60SUMNsZbm86BAu+/f3pl7TK4giEBzhi1WBVA83v7kYxaUVtUFdoYnO/IzfRwQFT32Dq1/9sc628FCBXvfNwnuLtqNBarIp9yw6eBQLNxdj0KOBTLHrd5dETdaYSIhpyGJ2OeCR4kUG/2M+crPSffPixcIrGUrtpKyiGoVbivHsvA14/NJTAAB3fLjM9PvsPHQMX6zchXH/WQwAmP3ns9C1ZRYAYPqyX3Bqh2y0dsjBwCrIi4EY+fn5XFhY6LQYmjj7H/Ox1YEUBoKQyJzYslHcwWfxsO6h4UhOInS+K1CIccvkkZbfc29JeW3678rqGqQmJ+FIeRUapcfffyeixcycH77dFNMQERUQ0Voi2kBE41X2pxPRu8r+hUSUF7JvgrJ9LRFdYIY8bkKUgCCYz+7D9o4c+z74Be6ftrLOttLyKqwwYMqsqWG8vXAbqmsYh8oq8dScdTikFBZiZvR/eA7W7y5BRVUNuk78HPPW7Ebv+2ZhT4n5VgbDioCIkgE8D2A4gJ4AfktE4fUJrwdwgJm7AHgSwN+Vc3sCuBxALwAFAF5QrqeL8qpqw+kQmFk1TL2sogoFT31de/2t+0s1ZYZk5noyScpbIRx5JOLD7qyhpeXV+M8PxzMElFdVo9d9s3ChkqZ7za7DWKc02sFU7rHaiU37juCuj5dj9c7DmPr9Fjw1Zz3+/fVG/LTtAK55LZAKZe+R8tpMtde9HrCC3PTGYny4eAfOffzLOterqKrBzBU7sWmv+kgp2ryTGXMEAwBsYOZNAEBE7wAYDSB0un80gPuV5Q8APEcBn8HRAN5h5nIAm4log3K9mFP3P24uxvbiMgzslIORz3wb0M4PXIAMZeLoaEU1UpPr1+FlDhRNYWbUMPDLwaNol9MQHSfMQIusdPw4cRiYGW8u3IaUJEJmegrW7CrBDW8U4ut1Ab/viSN6YGCnHMxcsQt/HNoV24vL0LVlFm57Zwn+t6QIn/3fmXh67nrMXlU3973LYssEFyCPhDfpdvfM2uW1u0pQ8NTxxHXDerTE0B4tMOGj5dgyeSTW7S5Bg9RkNG2Yite/24K8ZpnISE3Ggg2B9OgXPvstRvc5AQDwwpcb66Rf+deXG7EmzAX95+0H8fP2gwACOalyMtNQdKAMM1bswrw1e9C0QSp+3bcNGqQlY39pBfI7ZKO0vAr3f7oKyY1btFX7fwzPERDRGAAFzHyDsn4VgIHMfGvIMSuUY3Yo6xsBDERAOfzAzG8q218B8DkzfxDtnrkde3LmZY+p7rtyYHvkZKbhmXkbMKBjDpo0SEWbpg3QPqchpny9ybLJ2yHdcnUXgREEQbCTopduPly5f3uT8O2e8RoionEAxgFAcuNcZEY47s2F22qXf9xcbINkAUQJCILgerhG1V5lxmRxEYB2IettlW2qxxBRCoAmAPZrPBcAwMxTmDmfmfMzsuqHmQfp1jIL/Toc39+0QSo65WaiQKm5ahVm+TILgiBYiOq0lBkjgkUAuhJRRwQa8csBXBF2zDQAYxGw/Y8BMI+ZmYimAXibiJ4AcAKArgB+RAy6t8rCXy8+CYVbinHVoDxc8u/vUVFVg2X3n19bo3Z7cRmaNkxFVkbd3CF7S8rBzMhMT8Huw8ewvOgQRvdpU5t9cM2DBQCAcf9ZjM65mejaIgt3fbwcHZo1xNb9AQ+g24Z1Rc/WjfGvLzfi4V+fhJ2HjuLc7i1w4t2fo7Ka8faNA3HrWz9rLskoCIJ3GT+8OyZ/vqZ2/aQ2TTCgYw5e+XYzFt41FOt3l6BhegraZjfArW//jNM65aCiirF212HMVywJw3q0wJzV9XNytctugO0HItflGN67FU5u2xTbikvx07aDWKvMJ/RonYWiA0eRkZqMgt6tsOPAUSXnl3oRWVPiCIhoBICnACQDeJWZHyaiSQAKmXkaEWUA+A+AvgCKAVweMrk8EcB1AKoA3MbMn8e6X3gcweZ9pchMT0aLrPhT927bX4ZmjdKQGeaju/9IOfo9NAcrHwh4ti7cvB8nt22K5hEyY9bUMJKSCKXlVThWWY1+D7mzMpkgCPGRlpKE78efW/tub5k8El+v24OU5CSc2DILqclJaJyRgq37y5DXPJIRG1hRdAgXPvstXrumP37adgDPztuA35zaBoO75uI2JWPvf64fgD/+92ccKDvuJdU4IwW3DOmMv89cWyee4UBpBf799Uac17NVHasIEHCSWV50CKe0y1aNI5CAMovxe55zQbCCJg1SbXchHd67FT5fESjZuWXySGzaewRLdxzEr/uqOuLEpKq6BvdNW4l7LuyJ3YeP4bZ3l+CZy/uiXU5DMDM6TpiBz/7vTLRukoF+D83B45ecgtvfX4o5fzkbHZtn4lhldb2OaywiBZSJIrCYbnd/XicPvyAIxmnZON3WoLJFE4chMz0ZPe+dBcCeyOKftx1An3ZNQUQoOngUJzTJwNrdJejeqnHc17Q0sliITPdWWU6LIAgJw4Th3QEAb1w30NL7jDurU+3y2zcORG5WOhqmpeCOgm6YdusZlt47SN/22bUp2ts0bQAiMqQEouEZ91Gv0ja7IZbukIyasXjrhoHo1yEbz83bgOfmu6+Yj92c3LYJlslzU4eczDTcdHZn3HR259ptE0f0wJNz1qGswtz62neN6IHTOuWgXXbD2oRzAPD7IV1MvY9bkBGB4AjBkdJ4pYd3RpfmyEhNRo/W1vR4vEbDNH/30aZeN6D22QhSHRaaP/f2s3HtGXmmKYHGGYHvPGj2Obd7yzpKIJERRWA1kkxGlQt6tcKLV56KGwd3wuZHR9RuH3FSK6yaZG/uwbxmDW29Xywu6dcW+R0ix8r4gf552bj57M64elCH2m3h85mdcxvVSyFjhJPaNrHF9u9GRBFYTEsDLq1W8fq1/Z0WAZv3laKgd2skJ1GdUpVEhIZpKdgyeSQuy28X5QrmcXOIqcENPHbJKWjq09q5QVKSAk3TpNG9sWD8ubhxcEc8MLqX6rHpKeY0Y5MvPtmU63gRUQQWc0dBN6dFqEdflQLgdvPpstjlIRumWxetPfvPZ2FAxxwAzlVFa9Ig0NgX9GqFMf3quiCm+DxVbXLI/9+maQNMHNkzopvmSW3qpc6Ji3Y57hoZ2okoAovJcGHqicw0+2T6YcJQFN49LK5zj5o8ARhkzYMF6NoyCy9dnY+v/3YO1u0uiX2SBXz2f2cCAG48q2MdEwgAX08Uz7xtcB1FEItWTdw36vYaogh8RM/WjTH1ugG1dtW/XRB9tHLrOcY9JFo1yUDzRunY+MiI2vQdAHDlwA5RzgrwzqLthu+vRlA5N2mQivbNGjqirBtnpCBJaez2H6nAyW2b4sbBnXDlaYHvZfUuZ5STG9DrIvnQRb0N39OvcwNBRBHYQE+XeMI8evFJOPvEXACBRv6KAe2jHh8tPF4vyUmEJGUu4MNbBuG+X4XXLnKOOwu6xz7IZGoYSFa+j1TFxj1xZI/aRs3fhiF9NG2YJkWfDCKKwAZm/Gmw0yIACPimB/nrBd2QnZkW9fg0kybhggSH+7mNMjR5e1hhJ3/p6npBlbWJCu2kqqam1qTRPLN+3qrg/IUbCc5tuAlSz6UmaEQUgU189bchToug+rJ88oe6UZKvXdsfH94yCACQlhz/y6VWYDvYridpfOomjOgR9/0jofYf1TiQZiXoE7/mwQKc1Lb+ZGf4nIHThKZxtzLHT7wmmvAYAz24wYvOaUQR2ER4Omy3cEq7pnXW++floE+7gFdRitYWWwVWKcIYVERae29m9fEmX3xS7XJFdf28T+0d8BYJfreR5ieMfPeCdkae3BpDurVwWgzHkafNJrxiw8xISaqVVY/nRigf3DwINVHy7GVlaIuaNeM769chG5f1jx6PkKThRlk6szzG4sWr+kXdn5Hqrlczp1F0M6JX6eBjl9FQ3PW0JTBaGhsrGKjT1pySnBTSc4/vngfKKiOaW7ZMHmmrTT40YK1Jg9TayfJwbhvWNep17jV5cjuSHEEau8QOv/GRQNR344xUyz1rnEjQaMCilFCIIrCJZIcms/KaZWJMv7b49NYzox6n1vtPilPm1GQyxe5uxjsa/LfO6NIML12dHzF/+zWn50W9jtkT57FwS/xJ8Ln4cfN+AMCHt5yOId2iK7F4+dwBpwovpuG3An9ntvIBD4zupalR6Z+XjXfGDaqzLTdLvQpbLJiNTd6FXscoQWX21g2nabpXVkYKSo5VWSKL12ib3aB2+adtBwEETG1GOzV/PLcLnplXP8OsE54/PvxZVZERgU3orSRkFlqUwMK7hmJKmFvluoeGx50JtE12Awzp1gJnxTB/2IHWtiU7Mw1f/nUI+uepm9LUJr/18PHvT69dXvtQQZQjnSeYFuXpy/vWbrvnQvNMY38+78R6235zanxVvoxSI7YhAAYVARHlENFsIlqv/FVNYkNEY5Vj1hPR2JDtXxLRWiJaonwSevq+a4tGtt6vo8aAsJaNM+rZ7Y2YQk5smYVXr+mPN64bEPc1APt7a3nNMyN6KhkdEYTmd4rX5KYFM8wrGSmBzkOLkBFhaOyA0d9FrefvVGxCy8aSngIwPiIYD2AuM3cFMFdZrwMR5QC4D8BAAAMA3BemMH7HzH2Uzx6D8riaKwZGj+Q1m9Ym5mAZ2t1+HW2G/TbVxDTFZmGlIjDSoAbTccea37HCrt7zBGei7687s6Mj93UbRt+S0QCmKstTAVykcswFAGYzczEzHwAwG4C7x8YWYbeducrEYa8Tc92X9DOehvrB0cbz0ADGfrvwdNrxuuVqwYiSOTWkBsIrY/PRpmkD1ePOtaBTEJ591S6s/C28hFFF0JKZdyrLuwC0VDmmDYDQ7GE7lG1BXlPMQvdQgseJN7PZF7tdtrd9pJuYkJO/ReP4JrzNpKD38ajc+y3OsaS3XQt1CHjk14HAO2ZgaI+WEV2erxqUhykhcRBf/nUI/vcHe+r4GsXseJBEIaYiIKI5RLRC5TM69DgOjBf19pt+x8wnARisfK6KIsc4IiokosK9e/fqvI07GHXKCbbda1iPlnjkYnN6w4D2OYNmMfIX2U16ij43zBtDipaHYtbYysyKWmro7UsF8zllpafUOhaET4w/NubkOsoMqDvyaJCW7Jn6CTI1rE7Mp5KZhzFzb5XPJwB2E1FrAFD+qtn4iwCEjo3bKtvAzMG/JQDeRmAOIZIcU5g5n5nzc3Od90aJBzsHPFkZKbobQTWG9QiYAbTK/s9LTzF8Tyc5rVOzetsu6nMCBndtHvc1gw3ruLM64byeaoNm89Br6khR8klVR7F9XZLfrl7uqND7EOm/75+GHg/gs9MsdP8o9Spnfsdo92QagKAX0FgAn6gcMwvA+USUrUwSnw9gFhGlEFFzACCiVAAXAlhhUB5BwaxEai+P7R+IKNV4ucqqKLklPMqdw7sb8i4J/hR3jeih6zrrHx6u+156O+bBmIC3b4weZxFOaL8gLTlJd2nNUBfSf15iX+fBqbkIt2PUYDYZwHtEdD2ArQAuBQAiygdwMzPfwMzFRPQggEXKOZOUbZkIKIRUAMkA5gB4yaA8gkVoVSztXVYI3gyyGxozd8Ufoa2/n9ZAZ/W5oKmqT0jyQS0/dWhSvKYN09BE4/Oh1aVZsBdDioCZ9wMYqrK9EMANIeuvAng17JhSANEzbwlx40QkbGoy6a4uZSXBUpBGmKQxMjvIWzcMxO9eXlhnm13pKf73hzNAOnO2qkUJa/E2C0+OqtV0GBpbkpOZhuLSCk3nCdbiPidrwRTMzrGvJRiusto9U3ErHrgAvU0oaq63V35Gl/pzCXrNJvFC0O/mq3b8f3/cFvM8M2IhFt89DD/eVa8fGRdPX97HlOv4FVEENhNvIXenuW1Y/bQAbkatME48qE3uXnN6Hp67oi/SoiiJX/dtg6X3no8fJw5FrxOMKyQtEOmv4aDWoB/WUHgmXjUQ2kEhIrQwKbI3r5mYnIwgisBmmjeyx6/d7L65U2m0nUbt97ptWFdcePIJUXvfAzvmoEnDVLTIsi+FAYF0e6bFG1AVrwfcsUprnAn0jlAuzZdJ41BEESQodqfXzbbJ/OEGgtXmpv9RPa/P69f2x0V926ju08vP95yn+dh43Djj1e/xWoZObGlNvi3dJjHT6t8lBhJmJ5jCtFvPRKVKGUgv8tq1/XHta4vqbR87qANG9TmhtrHtEmHexMzSh9k6AvSCjWFuVjr2lpRrOsfukZ5VsTRaLts/L5BC47ExJ6umGvczMiJIUOz2GmqX0xCdcu3NrmoVfdo2Vd3+wOje6Nehbprqn3T02K0m2MvV09RmqAQdanl03GYp1NLDD+auuiS/neQYCkNGBAmKGYVh/Ep2Zhomje6FI+Wxe405mWlo3igd+46UW17KMRbBXrGeTvcNgzvizuHddd/LbYV6wt1Z1ZDGPzIyInCAm85Wz2djJmZmHg0Sq9ylW1hsgmfW1YPy8PshXUyQxj6CE6anRBjRqNEgNblOMJnW893kKgzot/kndnpL/YgiSFCssNef1LZJPW+LrIwUZKS66zFqZpNn1nHc0SgGG7cXfneqoeu0y4kdHR4ccd41Qv9owgq0dPZDG3/RA3UR05AT2NBulFuU8+cfY07BwI7NcPv7SwE4U3BcUCfYuBnJcDrztsER6xCEEqxnPO6sznHfy0wSPIO95YgicIAcG1I1V1iY/C3UFt3W4zUPEonwxjArPQUlGuY5QtGaIiSveabjcyKhaNEDoccM6twMp7Zvapk8XsNdY3qfcL0N5fGszA8ffKGeuqyPZffwEu/eNMiUvEZGCW8MNXWSE6QjrTegrEuLLHz0e28U07EDUQQOYFZxkk8iVIV6ZWw+/nWl9fn8RvcxJ2jKTF6/tr/t9+yc28iUvEZGSQ1znfGTl4w2neef70Mvogg8TPMs9UnRri2y6pQgNJsuuVmWXdso3Vq5VzarCRaZAYABHXO0pc92xzy3YcxIgudnRBF4mEiPfmiDYAUntW3iKvuwn/nNqce9uEJjR967aVBc9Qys5M3rB1p2bdEDxnDXkyLoItLQ3yv1YwVtvHjlqXX+hpKvpE0AUG8UGF572Gn0Fs3Rg5/MYFYgisDDROoF+fmlaKCjiIxXKOjdWlmK/rvqKaBjRe98aPfoOZaqLMxFpaXzI6OGyIgicAmTRusvqk0g3H5e/ToBmSbl4vcas247C00NlpX0Gh00BH+p0aSB+dliX7km+kS9FdHutUgjbwhDioCIcohoNhGtV/5mRzhuJhEdJKLPwrZ3JKKFRLSBiN4lIn+9xSEM61G/AEosiIB+efW/cj09w0QhOYl8N1G8ZfLIiBlQo3HN6XnHzzO5Af3tgHYR91npWSWTxcYwOiIYD2AuM3cFMFdZV+MxAFepbP87gCeZuQuAAwCuNyiPZ4mnr0QQlzi/MUOpgbBQKfGo57kZ0i0XAHD/qF6W2ev/rDJCDWLFKCSIXQWfEhWjimA0gKnK8lQAF6kdxMxzAZSEbqNAGOS5AD6Idb4fsLuQTKLhB3VIBPQ8oXHtMqAvC6hatHlLk0pF1uLgYzzqlBOcu7nHMaoIWjLzTmV5FwA99o1mAA4yczAGfgcA90Uo2UQ8ekDyqxzHUvuzSwg1+QVHgjU6HpxrTs+rs75l8kh0TpAaEkBgpCPER0xFQERziGiFymd06HEc6NJa9jYS0TgiKiSiwr1791p1G9toZkK+IYKMJEb3CfQCO+UmdvHytJQkdA2ZD6gdEei4xvm9WmHVpAvMFcxFxOoWSccpMjEVATMPY+beKp9PAOwmotYAoPzdo+Pe+wE0JaKgi0tbAEVR5JjCzPnMnJ+bm6vjNu5kcVhlq3ja8yQi1cliP9KpeWIrgnUPDccJIVlBGymeYXo7Ag3TrPUoc7Jb0ihD/X9b82CBzZJ4D6OmoWkAxirLYwF8ovVEZQQxH8CYeM5PNKIN8SM+yASkq5QaFBKbLZNH1pqJ0lwWPezkADVSJLUfvej0YvQpmgzgPCJaD2CYsg4iyieil4MHEdE3AN4HMJSIdhBRcHx6J4C/ENEGBOYMXjEoj2eJ9v5EepAb+TReIJTgYN+vFrIWUSZ7/fqdCPox1JIw834AQ1W2FwK4IWRdtXoJM28CMMCIDIlCPLb+YARxw7RklFVUmy2SJ2AAN53VCWd2be60KI6Rk5mG4tIKp8UQPIy7xpU+ZcvkkYZsq3eP7GmaLG7j4lNjO5JNGNEDg7t6f97IbC7NjxzcZRVuy28UikwVR0YUgUvQOyB447rjA6krBrY3WRr30LSBb4PNNXOsUn00eONZnWyWJDK/H+KOkpaCOqIIXMIJTTMwWId5o1+Hut5C5/XUn6LC63TOzcSZXfxrEgry0e9Px/Q/Ol8hDYjcobmjwB1F7gV1RBG4hIZpKfiPjoyQ4YnlXro632yRXM/c24fgEgfMH26je6vG6HWC8xXSAOtrYcTLyJNbY0DHHKfFcC3idiK4GokB8hYtskxOWWGQ684I1Ad//or6tRyE48iIwEHUcudf0Mt/Jh4hsdj4yAjcOLij02IAAE7t0NRpETyBKAIHUZvkvfK0DoaumWg96AT7d3xBchJJOgePIaYhB7mjoBtuOttczw4pUym4gbxmx1N+bHpkhIOSCFqQEYGDpKckm25TTbQCHQn27/iG3w5oVxv5nuRg50TqdWhDFIHLMJoWINHqFYuJwZsQES7vb79Hl9kjbL8gisAjFPRqpek4K6tACYIefFAiImGQOQKXEakDrMU/+9s7z4mYgdGryHjAu+gpmmMWYgqKj8RqNRKINiG55wFtk8BtsxuaX3rQaeS99ixOFE1yc64jNyOKwCMkJ/nzp5Ienndxg2nIiVGJF/Fn6+JBEszioxmZKzaGnvxVZlPtAtPQnNW7bZfBi8gcgUdING8gwV6euPQU2+95y9md0bddU9vvG0ppuT/rdOjFp/1M7zHqlNh5+RMRUX/mcFEf+5+fdjkNbU8K2Kpxep11MQ1pQxSBy+ic20h1+6DOzdDH4d6VE0T0opIRkiaCzgNOBnXZydWD8uqsd22h/j4JdTGkCIgoh4hmE9F65W92hONmEtFBIvosbPvrRLSZiJYonz5G5EkETgjzFgqlfU5DAECrRPMMikKkyeIF48+1WRJv8tBFvbFo4jCnxbCNcIU34qTWDkniLYyOCMYDmMvMXQHMVdbVeAzAVRH2/Y2Z+yifJQblSWiCveMJI/xT5CPSiKBFVrr6DqEOGanJyPXxd5VoKVeswqgiGA1gqrI8FcBFagcx81wAJQbv5TuW338+1j88vHY9RXEh7RCS0Mtv3HpOFwCSekLQhjwm2jCqCFoy805leReAeJLpP0xEy4joSSLyb9clhDsKuuGOgm7IykitEyl8z4U9AADJPnq6w//Tm6X2raADGRFoI6b7KBHNAaCW6GZi6AozMxHpnaKfgIACSQMwBcCdACZFkGMcgHEA0L594hZrB4DfD+miur1pw0Ahd1+5koa9yE5EqwrepWlDyb2lhZiKgJkjzjQR0W4ias3MO4moNYA9em4eMpooJ6LXAPw1yrFTEFAWyM/P93VrkJ3pn4fbRypPsIBozhfCcYyahqYBGKssjwXwiZ6TFeUBChh8LwKwwqA8Cc+mR0agdRP/PNwtwvzCfd0DEDQx67aznBbBcxhVBJMBnEdE6wEMU9ZBRPlE9HLwICL6BsD7AIYS0Q4iukDZ9RYRLQewHEBzAA8ZlCfh8Ys/eJDf9m+PxXcfH5SKZUiIhVotcCE6hlJMMPN+AENVthcCuCFkfXCE88UZXIhKUhKhWaOQUYEoAiEGMj+sH4ksFjxFeqo8skJ0/DZqNgN5qwRPkZGajC2TRzothuBiRA/oRxSBIAgJhZ/ibMxCFIEgCAlFY6nbrRtRBIIgJBQZqcno3irLaTE8hSgCQRAEnyOKQPAMF/U5wWkRBI8gSQn1IYpA8AwN06WyqqANUQP6EEUgCELCIQMCfYgiEDyDvNuCVkQR6EMUgeAZ5OUWtBKpxKmgjhhdBUFIOP5wThes/OWQ02J4BlEEgiAkHAW9W6Ggt1o9LUENMQ0JnqBzbibO7JLrtBiCkJDIiEDwBHNvH+K0CIKQsMiIQBAEweeIIhAEQfA5oggEQRB8jiFFQEQ5RDSbiNYrf7NVjulDRN8T0UoiWkZEl4Xs60hEC4loAxG9S0RpRuQRBEEQ9GN0RDAewFxm7gpgrrIeThmAq5m5F4ACAE8RUVNl398BPMnMXQAcAHC9QXkEQRAEnRhVBKMBTFWWpwK4KPwAZl7HzOuV5V8A7AGQS4H0gOcC+CDa+YIgCIK1GFUELZl5p7K8C0DLaAcT0QAAaQA2AmgG4CAzVym7dwBoY1AeQRAEQScx4wiIaA4AtRC9iaErzMxExFGu0xrAfwCMZeYavfnCiWgcgHEA0L59e13nCoIgCJGJqQiYeVikfUS0m4haM/NOpaHfE+G4xgCmA5jIzD8om/cDaEpEKcqooC2AoihyTAEwRbleCRGtjSW7gzQHsM9pIWLgdhndLh8gMpqB2+UD3C+jHvk6qG00Glk8DcBYAJOVv5+EH6B4An0M4A1mDs4HBEcQ8wGMAfBOpPMjsJaZ8w3KbhlEVOhm+QD3y+h2+QCR0QzcLh/gfhnNkM/oHMFkAOcR0XoAw5R1EFE+Eb2sHHMpgLMAXENES5RPH2XfnQD+QkQbEJgzeMWgPIIgCIJODI0ImHk/gKEq2wsB3KAsvwngzQjnbwIwwIgMgiAIgjG8Glk8xWkBYuB2+QD3y+h2+QCR0QzcLh/gfhkNy0fMER19BEEQBB/g1RGBIAiCYBKeUgREVEBEa5XcRGrpLFwlDxFdQ0R7QybJb3BCzjCZXiWiPUS0wmlZgNjyENEQIjoU8h3ea7eMKjK1I6L5RLRKyaH1JzfL4tLvMIOIfiSipYrcD7hZFje+y0GIKJmIfiaiz+K+CDN74gMgGYGI5E4IRCcvBdDTzfIAuAbAc05/d2EynQXgVAArnJZFizwAhgD4zGk5w2RqDeBUZTkLwDqnnkUtsrj0OyQAjZTlVAALAZzmVlnc+C6HyPYXAG8b+Y29NCIYAGADM29i5goEYg9Gizz6YOavARQ7LUcQt8mjBWbeycw/KcslAFbDofQobpJFDxzgiLKaqnwcmbB0kyx6IaK2AEYCeDnWsdHwkiJoA2B7yLrTuYm0yvMbJf32B0TUzh7REo5ByrD9cyLq5bQwoRBRHoC+CPQiHSWGLK77DhWTxhIEMhLMZmbHvkONsrjxXX4KwB0AaoxcxEuKwIt8CiCPmU8GMBvHM7UK2vkJQAdmPgXAswD+56w4xyGiRgA+BHAbMx92sSyu/A6ZuZqZ+yCQXmYAEfV2sSyue5eJ6EIAe5h5sdFreUkRFAEI1cJRcxPZQEx5mHk/M5crqy8D6GeTbAkDMx8ODtuZeQaAVCJq7rBYIKJUBBret5j5IzfL4tbvMAgzHwQwH4F6JY4SSRaXvstnABhFRFsQME2fS0Sqwbux8JIiWASgKwWqmqUBuByBXEeulUdJxBdkFAL2W0EHRNSKlFS1FEhjnoRAwkInZSIE0qGsZuYn3C6LS7/DXFIKVBFRAwDnAVjjVlnc+C4z8wRmbsvMeQi0P/OY+cp4rmU06ZxtMHMVEd0KYBYCHjuvMvNKt8lDRJMAFDLzNAB/JKJRAKoQmBC9xil5gxDRfxHwImlORDsA3MfMjuV4UpMHgck6MPOLCCQlvIWIqgAcBXA5K64SDnIGgKsALFfsygBwl9LbdoUsANoDrv4OWwOYSkTJCCim95g5fvdHC2Rx+7tsJhJZLAiC4HO8ZBoSBEEQLEAUgSAIgs8RRSAIguBzRBEIgiD4HFEEgiAIPkcUgSBEgYiahWSc3EVERcryESJ6wWn5BMEMxH1UEDRCRPcDOMLM/3RaFkEwExkRCEIcKDn+P1OW7yeiqUT0DRFtJaKLiegfRLSciGYqKSBARP2I6CsiWkxEs8KiVQXBMUQRCII5dAZwLgLpB94EMJ+ZT0IgknekogyeBTCGmfsBeBXAw04JKwiheCbFhCC4nM+ZuZKIliOQcmSmsn05gDwA3QD0BjBbSfuTDGCnA3IKQj1EEQiCOZQDADPXEFFlSC6fGgTeMwKwkpkHOSWgIERCTEOCYA9rAeQS0SAgkDraLQViBEEUgSDYgFLOdAyAvxPRUgBLAJzuqFCCoCDuo4IgCD5HRgSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPgcUQSCIAg+RxSBIAiCzxFFIAiC4HNEEQiCIPic/wcvziJ0eY2VRAAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "librosa.display.waveplot(samples_out, sr=sr)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "special-delicious", - "metadata": {}, - "outputs": [], - "source": [ - "import getpass" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "seasonal-consensus", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['GetPassWarning',\n", - " '__all__',\n", - " '__builtins__',\n", - " '__cached__',\n", - " '__doc__',\n", - " '__file__',\n", - " '__loader__',\n", - " '__name__',\n", - " '__package__',\n", - " '__spec__',\n", - " '_raw_input',\n", - " 'contextlib',\n", - " 'fallback_getpass',\n", - " 'getpass',\n", - " 'getuser',\n", - " 'io',\n", - " 'os',\n", - " 'sys',\n", - " 'termios',\n", - " 'unix_getpass',\n", - " 'warnings',\n", - " 'win_getpass']" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(getpass)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "dress-distinction", - "metadata": {}, - "outputs": [], - "source": [ - "getpass?" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "rental-anthony", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Worker:" - ] - } - ], - "source": [ - "import multiprocessing\n", - "import cProfile\n", - "import time\n", - "\n", - "def worker(num):\n", - " time.sleep(3)\n", - " print('Worker:', num)\n", - "\n", - "def profile_worker(num):\n", - " cProfile.runctx('worker(num)', globals(), locals(), 'profile-%d.out' %num)\n", - "\n", - "\n", - "\n", - "for i in range(5):\n", - " p = multiprocessing.Process(target=profile_worker, args=(i,))\n", - " p.start()" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "separated-restriction", - "metadata": {}, - "outputs": [], - "source": [ - "!ls" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "painted-variable", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(2, 2)\n", - "[ 1 20]\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "l = [(1, 20), (2, 30)]\n", - "scores = np.array(l)\n", - "print(scores.shape)\n", - "print(scores[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "satellite-insider", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0 1]\n" - ] - } - ], - "source": [ - "sort_idx = np.argsort(scores[:, -1])\n", - "print(sort_idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "developed-thirty", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - } - ], - "source": [ - "sorted_val_scores = scores[sort_idx][::1]\n", - "print(sorted_val_scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "official-bench", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - } - ], - "source": [ - "sorted_val_scores = scores[sort_idx]\n", - "print(sorted_val_scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "ranking-camera", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "b'\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x14\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x1e\\x00\\x00\\x00\\x00\\x00\\x00\\x00'\n", - "[ 1 20 2 30]\n", - "[[ 1 20]\n", - " [ 2 30]]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:1: DeprecationWarning: tostring() is deprecated. Use tobytes() instead.\n", - " \"\"\"Entry point for launching an IPython kernel.\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel_launcher.py:3: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead\n", - " This is separate from the ipykernel package so we can avoid doing imports until\n" - ] - } - ], - "source": [ - "a = scores.tostring()\n", - "print(a)\n", - "b = np.fromstring(a, scores.dtype)\n", - "print(b)\n", - "print(scores)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "breeding-proxy", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "numpy.int16" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.int16" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "coordinate-hungary", - "metadata": {}, - "outputs": [], - "source": [ - "dtype = np.dtype('int16')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "specified-jackson", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "int16\n", - "16\n" - ] - } - ], - "source": [ - "print(dtype)\n", - "dtype is np.int16\n", - "print(np.iinfo(dtype).bits)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "activated-insight", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/train_test.ipynb b/.notebook/train_test.ipynb deleted file mode 100644 index 67212e50a..000000000 --- a/.notebook/train_test.ipynb +++ /dev/null @@ -1,1887 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cloudy-glass", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['CUDA_VISISBLE_DEVICES'] = '0'" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "grand-stephen", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.0.0\n" - ] - } - ], - "source": [ - "import paddle\n", - "print(paddle.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "isolated-prize", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "romance-samuel", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'num_samples': 5, 'beam_size': 500, 'num_proc_bsearch': 8, 'num_conv_layers': 2, 'num_rnn_layers': 3, 'rnn_layer_size': 2048, 'alpha': 2.5, 'beta': 0.3, 'cutoff_prob': 1.0, 'cutoff_top_n': 40, 'use_gru': False, 'use_gpu': True, 'share_rnn_weights': True, 'infer_manifest': 'examples/aishell/data/manifest.dev', 'mean_std_path': 'examples/aishell/data/mean_std.npz', 'vocab_path': 'examples/aishell/data/vocab.txt', 'lang_model_path': 'models/lm/common_crawl_00.prune01111.trie.klm', 'model_path': 'examples/aishell/checkpoints/step_final', 'decoding_method': 'ctc_beam_search', 'error_rate_type': 'wer', 'specgram_type': 'linear'}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "import sys\n", - "import argparse\n", - "import functools\n", - "from utils.utility import add_arguments, print_arguments\n", - "parser = argparse.ArgumentParser(description=__doc__)\n", - "add_arg = functools.partial(add_arguments, argparser=parser)\n", - "# yapf: disable\n", - "add_arg('num_samples', int, 5, \"# of samples to infer.\")\n", - "add_arg('beam_size', int, 500, \"Beam search width.\")\n", - "add_arg('num_proc_bsearch', int, 8, \"# of CPUs for beam search.\")\n", - "add_arg('num_conv_layers', int, 2, \"# of convolution layers.\")\n", - "add_arg('num_rnn_layers', int, 3, \"# of recurrent layers.\")\n", - "add_arg('rnn_layer_size', int, 2048, \"# of recurrent cells per layer.\")\n", - "add_arg('alpha', float, 2.5, \"Coef of LM for beam search.\")\n", - "add_arg('beta', float, 0.3, \"Coef of WC for beam search.\")\n", - "add_arg('cutoff_prob', float, 1.0, \"Cutoff probability for pruning.\")\n", - "add_arg('cutoff_top_n', int, 40, \"Cutoff number for pruning.\")\n", - "add_arg('use_gru', bool, False, \"Use GRUs instead of simple RNNs.\")\n", - "add_arg('use_gpu', bool, True, \"Use GPU or not.\")\n", - "add_arg('share_rnn_weights',bool, True, \"Share input-hidden weights across \"\n", - " \"bi-directional RNNs. Not for GRU.\")\n", - "add_arg('infer_manifest', str,\n", - " 'examples/aishell/data/manifest.dev',\n", - " \"Filepath of manifest to infer.\")\n", - "add_arg('mean_std_path', str,\n", - " 'examples/aishell/data/mean_std.npz',\n", - " \"Filepath of normalizer's mean & std.\")\n", - "add_arg('vocab_path', str,\n", - " 'examples/aishell/data/vocab.txt',\n", - " \"Filepath of vocabulary.\")\n", - "add_arg('lang_model_path', str,\n", - " 'models/lm/common_crawl_00.prune01111.trie.klm',\n", - " \"Filepath for language model.\")\n", - "add_arg('model_path', str,\n", - " 'examples/aishell/checkpoints/step_final',\n", - " \"If None, the training starts from scratch, \"\n", - " \"otherwise, it resumes from the pre-trained model.\")\n", - "add_arg('decoding_method', str,\n", - " 'ctc_beam_search',\n", - " \"Decoding method. Options: ctc_beam_search, ctc_greedy\",\n", - " choices = ['ctc_beam_search', 'ctc_greedy'])\n", - "add_arg('error_rate_type', str,\n", - " 'wer',\n", - " \"Error rate type for evaluation.\",\n", - " choices=['wer', 'cer'])\n", - "add_arg('specgram_type', str,\n", - " 'linear',\n", - " \"Audio feature type. Options: linear, mfcc.\",\n", - " choices=['linear', 'mfcc'])\n", - "# yapf: disable\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "timely-bikini", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/fftpack/__init__.py:103: DeprecationWarning: The module numpy.dual is deprecated. Instead of using dual, use the functions directly from numpy or scipy.\n", - " from numpy.dual import register_func\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/scipy/special/orthogonal.py:81: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " from numpy import (exp, inf, pi, sqrt, floor, sin, cos, around, int,\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:108: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " long_ = _make_signed(np.long)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/numba/core/types/__init__.py:109: DeprecationWarning: `np.long` is a deprecated alias for `np.compat.long`. To silence this warning, use `np.compat.long` by itself. In the likely event your code does not need to work on Python 2 you can use the builtin `int` for which `np.compat.long` is itself an alias. Doing this will not modify any behaviour and is safe. When replacing `np.long`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " ulong = _make_unsigned(np.long)\n" - ] - } - ], - "source": [ - "from data_utils.dataset import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=False,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "organized-warrior", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dataloader/dataloader_iter.py:354: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " if arr.dtype == np.object:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "test Tensor(shape=[5, 6], dtype=int32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[14 , 34 , 322 , 233 , 0 , 0 ],\n", - " [238 , 38 , 122 , 164 , 0 , 0 ],\n", - " [8 , 52 , 49 , 42 , 0 , 0 ],\n", - " [109 , 47 , 146 , 193 , 210 , 479 ],\n", - " [3330, 1751, 208 , 1923, 0 , 0 ]])\n", - "test raw 大时代里的的\n", - "test raw 煲汤受宠的的\n", - "audio len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "test len Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [4, 4, 4, 6, 4])\n", - "audio Tensor(shape=[5, 161, 186], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [[[ 1.11669052, 0.79015088, 0.93658292, ..., 0. , 0. , 0. ],\n", - " [ 0.83549136, 0.72643483, 0.83578080, ..., 0. , 0. , 0. ],\n", - " [-0.89155018, -0.18894747, -0.53357804, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.33386710, -0.81240511, 0.12869737, ..., 0. , 0. , 0. ],\n", - " [-0.17537928, 0.58380985, 0.70696265, ..., 0. , 0. , 0. ],\n", - " [-0.84175998, 1.22041416, 0.07929770, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[-0.35964420, 0.77392709, 0.71409988, ..., 0. , 0. , 0. ],\n", - " [-0.15990183, 0.42962283, 0.06222462, ..., 0. , 0. , 0. ],\n", - " [-0.31166190, -0.74864638, -0.52836996, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [-0.27546275, 0.32889456, 0.12410031, ..., 0. , 0. , 0. ],\n", - " [ 0.16264282, 0.49418071, -0.15960945, ..., 0. , 0. , 0. ],\n", - " [ 0.12476666, 0.00516864, 1.16021466, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.90202141, 1.48541915, 0.92062062, ..., 0. , 0. , 0. ],\n", - " [ 0.82661545, 1.37171340, 0.86746097, ..., 0. , 0. , 0. ],\n", - " [-0.62287915, -0.48645937, 0.35041964, ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [ 0.07376949, 0.07138316, 0.76355994, ..., 0. , 0. , 0. ],\n", - " [-0.32306790, 0.43247896, 1.27311838, ..., 0. , 0. , 0. ],\n", - " [-0.97667056, 0.60747612, 0.79181534, ..., 0. , 0. , 0. ]],\n", - "\n", - " [[ 0.72022128, 0.95428467, 0.92766261, ..., 0.29105374, -0.45564806, -0.62151009],\n", - " [ 0.42083180, 0.49279949, 0.82724041, ..., -0.17333922, -1.45363355, -0.61673522],\n", - " [-0.76116520, -0.84750438, -0.09512503, ..., -1.01497340, -1.42781055, -0.80859023],\n", - " ...,\n", - " [-0.23009977, 1.06155431, 1.09065628, ..., 0.25581080, 0.53794998, -1.22650719],\n", - " [-1.37693381, 0.30778193, 0.17152318, ..., 0.51650339, 0.25580606, 0.83097816],\n", - " [-1.62180591, 1.30567718, 1.09928656, ..., -0.77590007, 1.27712476, 0.53189957]],\n", - "\n", - " [[ 1.03205252, -0.51535392, 0.21077573, ..., 0.76618457, 1.27425683, 1.52250278],\n", - " [ 0.82059991, 0.43990925, 0.13090958, ..., 0.86662549, 1.01687658, 1.48495352],\n", - " [-0.75489789, -0.01997089, -0.65174174, ..., 0.09061214, -0.55211234, -0.01614586],\n", - " ...,\n", - " [ 0.50985396, 1.84555030, 0.79185146, ..., 1.13666189, 1.19898069, 1.98158395],\n", - " [ 1.98721015, 2.52385354, 1.11714780, ..., 0.19416514, 1.11329341, 0.64460152],\n", - " [ 2.69512844, 1.90993905, 0.50245082, ..., -0.50902629, 0.03333465, -1.24584770]]])\n" - ] - } - ], - "source": [ - " for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n", - " print('test', text)\n", - " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n", - " print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n", - " print('audio len', audio_len)\n", - " print('test len', text_len)\n", - " print('audio', audio)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "confidential-radius", - "metadata": {}, - "outputs": [], - "source": [ - "# reader = batch_reader()\n", - "# audio, test , audio_len, text_len = reader.next()\n", - "# print('test', text)\n", - "# print('t len', text_len) #[B, T]\n", - "# print('audio len', audio_len)\n", - "# print(audio)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "future-vermont", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "煲汤受宠\n" - ] - } - ], - "source": [ - "print(u'\\u7172\\u6c64\\u53d7\\u5ba0')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dental-sweden", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "sunrise-contact", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hispanic-asthma", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hearing-leadership", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "skilled-friday", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "copyrighted-measure", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "employed-lightweight", - "metadata": {}, - "outputs": [], - "source": [ - "from model_utils.network import DeepSpeech2, DeepSpeech2Loss\n", - "\n", - "from data_utils.dataset import create_dataloader\n", - "batch_reader = create_dataloader(\n", - " manifest_path=args.infer_manifest,\n", - " vocab_filepath=args.vocab_path,\n", - " mean_std_filepath=args.mean_std_path,\n", - " augmentation_config='{}',\n", - " #max_duration=float('inf'),\n", - " max_duration=27.0,\n", - " min_duration=0.0,\n", - " stride_ms=10.0,\n", - " window_ms=20.0,\n", - " max_freq=None,\n", - " specgram_type=args.specgram_type,\n", - " use_dB_normalization=True,\n", - " random_seed=0,\n", - " keep_transcription_text=False,\n", - " is_training=False,\n", - " batch_size=args.num_samples,\n", - " sortagrad=True,\n", - " shuffle_method=None,\n", - " dist=False)\n", - "\n", - "\n", - "import paddle\n", - "from paddle import nn\n", - "from paddle.nn import functional as F\n", - "from paddle.nn import initializer as I\n", - "\n", - "import math\n", - "\n", - "def brelu(x, t_min=0.0, t_max=24.0, name=None):\n", - " t_min = paddle.to_tensor(t_min)\n", - " t_max = paddle.to_tensor(t_max)\n", - " return x.maximum(t_min).minimum(t_max)\n", - "\n", - "def sequence_mask(x_len, max_len=None, dtype='float32'):\n", - " max_len = max_len or x_len.max()\n", - " x_len = paddle.unsqueeze(x_len, -1)\n", - " row_vector = paddle.arange(max_len)\n", - " mask = row_vector > x_len # maybe a bug\n", - " mask = paddle.cast(mask, dtype)\n", - " print(f'seq mask: {mask}')\n", - " return mask\n", - "\n", - "\n", - "class ConvBn(nn.Layer):\n", - " def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,\n", - " padding, act):\n", - "\n", - " super().__init__()\n", - " self.kernel_size = kernel_size\n", - " self.stride = stride\n", - " self.padding = padding\n", - "\n", - " self.conv = nn.Conv2D(\n", - " num_channels_in,\n", - " num_channels_out,\n", - " kernel_size=kernel_size,\n", - " stride=stride,\n", - " padding=padding,\n", - " weight_attr=None,\n", - " bias_attr=None,\n", - " data_format='NCHW')\n", - "\n", - " self.bn = nn.BatchNorm2D(\n", - " num_channels_out,\n", - " weight_attr=None,\n", - " bias_attr=None,\n", - " data_format='NCHW')\n", - " self.act = F.relu if act == 'relu' else brelu\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x(Tensor): audio, shape [B, C, D, T]\n", - " \"\"\"\n", - " x = self.conv(x)\n", - " x = self.bn(x)\n", - " x = self.act(x)\n", - "\n", - " x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]\n", - " ) // self.stride[1] + 1\n", - "\n", - " # reset padding part to 0\n", - " masks = sequence_mask(x_len) #[B, T]\n", - " masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]\n", - " x = x.multiply(masks)\n", - "\n", - " return x, x_len\n", - "\n", - "\n", - "class ConvStack(nn.Layer):\n", - " def __init__(self, feat_size, num_stacks):\n", - " super().__init__()\n", - " self.feat_size = feat_size # D\n", - " self.num_stacks = num_stacks\n", - "\n", - " self.conv_in = ConvBn(\n", - " num_channels_in=1,\n", - " num_channels_out=32,\n", - " kernel_size=(41, 11), #[D, T]\n", - " stride=(2, 3),\n", - " padding=(20, 5),\n", - " act='brelu')\n", - "\n", - " out_channel = 32\n", - " self.conv_stack = nn.Sequential([\n", - " ConvBn(\n", - " num_channels_in=32,\n", - " num_channels_out=out_channel,\n", - " kernel_size=(21, 11),\n", - " stride=(2, 1),\n", - " padding=(10, 5),\n", - " act='brelu') for i in range(num_stacks - 1)\n", - " ])\n", - "\n", - " # conv output feat_dim\n", - " output_height = (feat_size - 1) // 2 + 1\n", - " for i in range(self.num_stacks - 1):\n", - " output_height = (output_height - 1) // 2 + 1\n", - " self.output_height = out_channel * output_height\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x: shape [B, C, D, T]\n", - " x_len : shape [B]\n", - " \"\"\"\n", - " print(f\"conv in: {x_len}\")\n", - " x, x_len = self.conv_in(x, x_len)\n", - " for i, conv in enumerate(self.conv_stack):\n", - " print(f\"conv in: {x_len}\")\n", - " x, x_len = conv(x, x_len)\n", - " print(f\"conv out: {x_len}\")\n", - " return x, x_len\n", - " \n", - " \n", - "\n", - "class RNNCell(nn.RNNCellBase):\n", - " r\"\"\"\n", - " Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it \n", - " computes the outputs and updates states.\n", - " The formula used is as follows:\n", - " .. math::\n", - " h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})\n", - " y_{t} & = h_{t}\n", - " \n", - " where :math:`act` is for :attr:`activation`.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " hidden_size,\n", - " activation=\"tanh\",\n", - " weight_ih_attr=None,\n", - " weight_hh_attr=None,\n", - " bias_ih_attr=None,\n", - " bias_hh_attr=None,\n", - " name=None):\n", - " super().__init__()\n", - " std = 1.0 / math.sqrt(hidden_size)\n", - " self.weight_hh = self.create_parameter(\n", - " (hidden_size, hidden_size),\n", - " weight_hh_attr,\n", - " default_initializer=I.Uniform(-std, std))\n", - " # self.bias_ih = self.create_parameter(\n", - " # (hidden_size, ),\n", - " # bias_ih_attr,\n", - " # is_bias=True,\n", - " # default_initializer=I.Uniform(-std, std))\n", - " self.bias_ih = None\n", - " self.bias_hh = self.create_parameter(\n", - " (hidden_size, ),\n", - " bias_hh_attr,\n", - " is_bias=True,\n", - " default_initializer=I.Uniform(-std, std))\n", - "\n", - " self.hidden_size = hidden_size\n", - " if activation not in [\"tanh\", \"relu\", \"brelu\"]:\n", - " raise ValueError(\n", - " \"activation for SimpleRNNCell should be tanh or relu, \"\n", - " \"but get {}\".format(activation))\n", - " self.activation = activation\n", - " self._activation_fn = paddle.tanh \\\n", - " if activation == \"tanh\" \\\n", - " else F.relu\n", - " if activation == 'brelu':\n", - " self._activation_fn = brelu\n", - "\n", - " def forward(self, inputs, states=None):\n", - " if states is None:\n", - " states = self.get_initial_states(inputs, self.state_shape)\n", - " pre_h = states\n", - " i2h = inputs\n", - " if self.bias_ih is not None:\n", - " i2h += self.bias_ih\n", - " h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)\n", - " if self.bias_hh is not None:\n", - " h2h += self.bias_hh\n", - " h = self._activation_fn(i2h + h2h)\n", - " return h, h\n", - "\n", - " @property\n", - " def state_shape(self):\n", - " return (self.hidden_size, )\n", - "\n", - "\n", - "class GRUCellShare(nn.RNNCellBase):\n", - " r\"\"\"\n", - " Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, \n", - " it computes the outputs and updates states.\n", - " The formula for GRU used is as follows:\n", - " .. math::\n", - " r_{t} & = \\sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})\n", - " z_{t} & = \\sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})\n", - " \\widetilde{h}_{t} & = \\tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))\n", - " h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \\widetilde{h}_{t}\n", - " y_{t} & = h_{t}\n", - " \n", - " where :math:`\\sigma` is the sigmoid fucntion, and * is the elemetwise \n", - " multiplication operator.\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " input_size,\n", - " hidden_size,\n", - " weight_ih_attr=None,\n", - " weight_hh_attr=None,\n", - " bias_ih_attr=None,\n", - " bias_hh_attr=None,\n", - " name=None):\n", - " super().__init__()\n", - " std = 1.0 / math.sqrt(hidden_size)\n", - " self.weight_hh = self.create_parameter(\n", - " (3 * hidden_size, hidden_size),\n", - " weight_hh_attr,\n", - " default_initializer=I.Uniform(-std, std))\n", - " # self.bias_ih = self.create_parameter(\n", - " # (3 * hidden_size, ),\n", - " # bias_ih_attr,\n", - " # is_bias=True,\n", - " # default_initializer=I.Uniform(-std, std))\n", - " self.bias_ih = None\n", - " self.bias_hh = self.create_parameter(\n", - " (3 * hidden_size, ),\n", - " bias_hh_attr,\n", - " is_bias=True,\n", - " default_initializer=I.Uniform(-std, std))\n", - "\n", - " self.hidden_size = hidden_size\n", - " self.input_size = input_size\n", - " self._gate_activation = F.sigmoid\n", - " #self._activation = paddle.tanh\n", - " self._activation = F.relu\n", - "\n", - " def forward(self, inputs, states=None):\n", - " if states is None:\n", - " states = self.get_initial_states(inputs, self.state_shape)\n", - "\n", - " pre_hidden = states\n", - " x_gates = inputs\n", - " if self.bias_ih is not None:\n", - " x_gates = x_gates + self.bias_ih\n", - " h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)\n", - " if self.bias_hh is not None:\n", - " h_gates = h_gates + self.bias_hh\n", - "\n", - " x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)\n", - " h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)\n", - "\n", - " r = self._gate_activation(x_r + h_r)\n", - " z = self._gate_activation(x_z + h_z)\n", - " c = self._activation(x_c + r * h_c) # apply reset gate after mm\n", - " h = (pre_hidden - c) * z + c\n", - " # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru\n", - " #h = (1-z) * pre_hidden + z * c\n", - "\n", - " return h, h\n", - "\n", - " @property\n", - " def state_shape(self):\n", - " r\"\"\"\n", - " The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch\n", - " size would be automatically inserted into shape). The shape corresponds\n", - " to the shape of :math:`h_{t-1}`.\n", - " \"\"\"\n", - " return (self.hidden_size, )\n", - "\n", - "\n", - "class BiRNNWithBN(nn.Layer):\n", - " \"\"\"Bidirectonal simple rnn layer with sequence-wise batch normalization.\n", - " The batch normalization is only performed on input-state weights.\n", - "\n", - " :param name: Name of the layer parameters.\n", - " :type name: string\n", - " :param size: Dimension of RNN cells.\n", - " :type size: int\n", - " :param share_weights: Whether to share input-hidden weights between\n", - " forward and backward directional RNNs.\n", - " :type share_weights: bool\n", - " :return: Bidirectional simple rnn layer.\n", - " :rtype: Variable\n", - " \"\"\"\n", - "\n", - " def __init__(self, i_size, h_size, share_weights):\n", - " super().__init__()\n", - " self.share_weights = share_weights\n", - " if self.share_weights:\n", - " #input-hidden weights shared between bi-directional rnn.\n", - " self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", - " # batch norm is only performed on input-state projection\n", - " self.fw_bn = nn.BatchNorm1D(\n", - " h_size, bias_attr=None, data_format='NLC')\n", - " self.bw_fc = self.fw_fc\n", - " self.bw_bn = self.fw_bn\n", - " else:\n", - " self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", - " self.fw_bn = nn.BatchNorm1D(\n", - " h_size, bias_attr=None, data_format='NLC')\n", - " self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)\n", - " self.bw_bn = nn.BatchNorm1D(\n", - " h_size, bias_attr=None, data_format='NLC')\n", - "\n", - " self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n", - " self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')\n", - " self.fw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n", - " self.bw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n", - "\n", - " def forward(self, x, x_len):\n", - " # x, shape [B, T, D]\n", - " fw_x = self.fw_bn(self.fw_fc(x))\n", - " bw_x = self.bw_bn(self.bw_fc(x))\n", - " fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n", - " bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n", - " x = paddle.concat([fw_x, bw_x], axis=-1)\n", - " return x, x_len\n", - "\n", - "\n", - "class BiGRUWithBN(nn.Layer):\n", - " \"\"\"Bidirectonal gru layer with sequence-wise batch normalization.\n", - " The batch normalization is only performed on input-state weights.\n", - "\n", - " :param name: Name of the layer.\n", - " :type name: string\n", - " :param input: Input layer.\n", - " :type input: Variable\n", - " :param size: Dimension of GRU cells.\n", - " :type size: int\n", - " :param act: Activation type.\n", - " :type act: string\n", - " :return: Bidirectional GRU layer.\n", - " :rtype: Variable\n", - " \"\"\"\n", - "\n", - " def __init__(self, i_size, h_size, act):\n", - " super().__init__()\n", - " hidden_size = h_size * 3\n", - " self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n", - " self.fw_bn = nn.BatchNorm1D(\n", - " hidden_size, bias_attr=None, data_format='NLC')\n", - " self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)\n", - " self.bw_bn = nn.BatchNorm1D(\n", - " hidden_size, bias_attr=None, data_format='NLC')\n", - "\n", - " self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n", - " self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)\n", - " self.fw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]\n", - " self.bw_rnn = nn.RNN(\n", - " self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]\n", - "\n", - " def forward(self, x, x_len):\n", - " # x, shape [B, T, D]\n", - " fw_x = self.fw_bn(self.fw_fc(x))\n", - " bw_x = self.bw_bn(self.bw_fc(x))\n", - " fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)\n", - " bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)\n", - " x = paddle.concat([fw_x, bw_x], axis=-1)\n", - " return x, x_len\n", - "\n", - "\n", - "class RNNStack(nn.Layer):\n", - " \"\"\"RNN group with stacked bidirectional simple RNN or GRU layers.\n", - "\n", - " :param input: Input layer.\n", - " :type input: Variable\n", - " :param size: Dimension of RNN cells in each layer.\n", - " :type size: int\n", - " :param num_stacks: Number of stacked rnn layers.\n", - " :type num_stacks: int\n", - " :param use_gru: Use gru if set True. Use simple rnn if set False.\n", - " :type use_gru: bool\n", - " :param share_rnn_weights: Whether to share input-hidden weights between\n", - " forward and backward directional RNNs.\n", - " It is only available when use_gru=False.\n", - " :type share_weights: bool\n", - " :return: Output layer of the RNN group.\n", - " :rtype: Variable\n", - " \"\"\"\n", - "\n", - " def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):\n", - " super().__init__()\n", - " self.rnn_stacks = nn.LayerList()\n", - " for i in range(num_stacks):\n", - " if use_gru:\n", - " #default:GRU using tanh\n", - " self.rnn_stacks.append(\n", - " BiGRUWithBN(i_size=i_size, h_size=h_size, act=\"relu\"))\n", - " else:\n", - " self.rnn_stacks.append(\n", - " BiRNNWithBN(\n", - " i_size=i_size,\n", - " h_size=h_size,\n", - " share_weights=share_rnn_weights))\n", - " i_size = h_size * 2\n", - "\n", - " def forward(self, x, x_len):\n", - " \"\"\"\n", - " x: shape [B, T, D]\n", - " x_len: shpae [B]\n", - " \"\"\"\n", - " for i, rnn in enumerate(self.rnn_stacks):\n", - " x, x_len = rnn(x, x_len)\n", - " masks = sequence_mask(x_len) #[B, T]\n", - " masks = masks.unsqueeze(-1) # [B, T, 1]\n", - " x = x.multiply(masks)\n", - " return x, x_len\n", - "\n", - " \n", - "class DeepSpeech2Test(DeepSpeech2):\n", - " def __init__(self,\n", - " feat_size,\n", - " dict_size,\n", - " num_conv_layers=2,\n", - " num_rnn_layers=3,\n", - " rnn_size=256,\n", - " use_gru=False,\n", - " share_rnn_weights=True):\n", - " super().__init__(feat_size,\n", - " dict_size,\n", - " num_conv_layers=2,\n", - " num_rnn_layers=3,\n", - " rnn_size=256,\n", - " use_gru=False,\n", - " share_rnn_weights=True)\n", - " self.feat_size = feat_size # 161 for linear\n", - " self.dict_size = dict_size\n", - "\n", - " self.conv = ConvStack(feat_size, num_conv_layers)\n", - " \n", - "# self.fc = nn.Linear(1312, dict_size + 1)\n", - "\n", - " i_size = self.conv.output_height # H after conv stack\n", - " self.rnn = RNNStack(\n", - " i_size=i_size,\n", - " h_size=rnn_size,\n", - " num_stacks=num_rnn_layers,\n", - " use_gru=use_gru,\n", - " share_rnn_weights=share_rnn_weights)\n", - " \n", - " self.fc = nn.Linear(rnn_size * 2, dict_size + 1)\n", - " \n", - " def infer(self, audio, audio_len):\n", - " # [B, D, T] -> [B, C=1, D, T]\n", - " audio = audio.unsqueeze(1)\n", - "\n", - " # convolution group\n", - " x, audio_len = self.conv(audio, audio_len)\n", - " print('conv out', x.shape)\n", - "\n", - " # convert data from convolution feature map to sequence of vectors\n", - " B, C, D, T = paddle.shape(x)\n", - " x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]\n", - " x = x.reshape([B, T, C * D]) #[B, T, C*D]\n", - " print('rnn input', x.shape)\n", - "\n", - " # remove padding part\n", - " x, audio_len = self.rnn(x, audio_len) #[B, T, D]\n", - " print('rnn output', x.shape)\n", - "\n", - " logits = self.fc(x) #[B, T, V + 1]\n", - "\n", - " #ctcdecoder need probs, not log_probs\n", - " probs = F.softmax(logits)\n", - "\n", - " return logits, probs, audio_len\n", - "\n", - " def forward(self, audio, audio_len, text, text_len):\n", - " \"\"\"\n", - " audio: shape [B, D, T]\n", - " text: shape [B, T]\n", - " audio_len: shape [B]\n", - " text_len: shape [B]\n", - " \"\"\"\n", - " return self.infer(audio, audio_len)\n", - " \n", - "\n", - "feat_dim=161\n", - "\n", - "model = DeepSpeech2Test(\n", - " feat_size=feat_dim,\n", - " dict_size=batch_reader.dataset.vocab_size,\n", - " num_conv_layers=args.num_conv_layers,\n", - " num_rnn_layers=args.num_rnn_layers,\n", - " rnn_size=1024,\n", - " use_gru=args.use_gru,\n", - " share_rnn_weights=args.share_rnn_weights,\n", - " )\n", - "dp_model = model\n", - "#dp_model = paddle.DataParallel(model)\n", - "\n", - "loss_fn = DeepSpeech2Loss(batch_reader.dataset.vocab_size)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "divided-incentive", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "discrete-conjunction", - "metadata": {}, - "outputs": [], - "source": [ - "audio, audio_len, text, text_len = None, None, None, None\n", - "\n", - "for idx, inputs in enumerate(batch_reader):\n", - " audio, audio_len, text, text_len = inputs\n", - "# print(idx)\n", - "# print('a', audio.shape, audio.place)\n", - "# print('t', text)\n", - "# print('al', audio_len)\n", - "# print('tl', text_len)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "protected-announcement", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,\n", - " [163, 167, 180, 186, 186])\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "conv in: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [55, 56, 60, 62, 62])\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "conv out: Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [55, 56, 60, 62, 62])\n", - "conv out [5, 32, 41, 62]\n", - "rnn input [5, 62, 1312]\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working\n", - " return (isinstance(seq, collections.Sequence) and\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "seq mask: Tensor(shape=[5, 62], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", - " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", - "rnn output [5, 62, 2048]\n", - "logits len Tensor(shape=[5], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [55, 56, 60, 62, 62])\n", - "loss Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [2316.82153320])\n" - ] - } - ], - "source": [ - "outputs = dp_model(audio, audio_len, text, text_len)\n", - "logits, _, logits_len = outputs\n", - "print('logits len', logits_len)\n", - "loss = loss_fn.forward(logits, text, logits_len, text_len)\n", - "print('loss', loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "universal-myrtle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: None\n", - "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: None\n", - "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: None\n", - "param grad: fc.bias: shape: [4299] stop_grad: False grad: None\n" - ] - } - ], - "source": [ - "for n, p in dp_model.named_parameters():\n", - " print(\n", - " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "referenced-double", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "param grad: conv.conv_in.conv.weight: shape: [32, 1, 41, 11] stop_grad: False grad: [[[[ 2.1243238 1.696022 3.770659 ... 5.234652 5.4865217\n", - " 4.757795 ]\n", - " [ 2.651376 2.3109848 4.428488 ... 5.353201 8.703288\n", - " 5.1787405 ]\n", - " [ 2.7511077 1.8823049 2.1875212 ... 3.4821286 6.386543\n", - " 3.5026932 ]\n", - " ...\n", - " [ 1.9173846 1.8623551 0.5601456 ... 2.8375719 3.8496673\n", - " 2.359191 ]\n", - " [ 2.3827765 2.497965 1.5914664 ... 2.220721 3.4617734\n", - " 4.829253 ]\n", - " [ 1.6855702 1.5040786 1.8793598 ... 4.0773935 3.176893\n", - " 3.7477999 ]]]\n", - "\n", - "\n", - " [[[ 1.8451455 2.0091445 1.5225713 ... 1.524528 0.17764974\n", - " 1.0245132 ]\n", - " [ 1.9388857 1.3873467 2.044691 ... 0.92544 -0.9746763\n", - " -0.41603735]\n", - " [ 2.6814485 2.6096234 1.6802506 ... 1.902397 1.6837387\n", - " -0.96788657]\n", - " ...\n", - " [ 4.3675485 1.9822174 1.1695029 ... 1.4672399 3.2029557\n", - " 2.6364415 ]\n", - " [ 3.2536 1.1792442 -0.5618002 ... 2.101127 1.904225\n", - " 3.3839993 ]\n", - " [ 1.9118482 1.0651072 0.5409893 ... 2.6783593 1.6871439\n", - " 4.1078367 ]]]\n", - "\n", - "\n", - " [[[-4.412424 -1.7111907 -1.7722387 ... -4.3383503 -6.2393785\n", - " -6.139402 ]\n", - " [-2.260428 -1.0250616 -2.0550888 ... -5.353946 -4.29947\n", - " -6.158736 ]\n", - " [-1.4927872 0.7552787 -0.0702923 ... -4.485656 -4.0794134\n", - " -5.416684 ]\n", - " ...\n", - " [ 2.9100134 4.156195 4.357041 ... -3.569804 -1.8634341\n", - " -0.8772557 ]\n", - " [ 1.6895763 3.4314504 4.1192107 ... -1.380024 -2.3234155\n", - " -3.6650617 ]\n", - " [ 2.4190075 1.007498 3.1173465 ... -0.96318084 -3.6175003\n", - " -2.5240796 ]]]\n", - "\n", - "\n", - " ...\n", - "\n", - "\n", - " [[[-0.6865506 -0.60106415 -1.5555015 ... 2.0853553 1.900961\n", - " 2.101063 ]\n", - " [-0.31686288 -1.4362946 -1.4929098 ... 0.15085456 1.4540495\n", - " 1.4128599 ]\n", - " [-0.57852304 -0.8204216 -2.3264258 ... 1.4970423 0.54599845\n", - " 1.6222539 ]\n", - " ...\n", - " [ 0.32624918 0.96004546 -0.7476514 ... 2.2786083 2.1000178\n", - " 2.7494807 ]\n", - " [-1.6967826 -0.78979015 -1.8424999 ... 1.0620685 2.0544293\n", - " 2.2483966 ]\n", - " [ 0.8192332 2.601636 -2.6636481 ... 0.26625186 1.7610842\n", - " 1.7467536 ]]]\n", - "\n", - "\n", - " [[[ 0.9140297 0.42424175 1.4352363 ... -2.3022954 -3.001058\n", - " -2.6987422 ]\n", - " [ 0.4491998 -0.10698095 1.5089144 ... -3.2831016 -3.6055021\n", - " -3.6595795 ]\n", - " [ 2.6818252 -1.5750014 -0.34812498 ... -4.4137015 -4.250422\n", - " -3.481941 ]\n", - " ...\n", - " [ 1.4232106 2.9689102 3.9547806 ... -0.481165 0.28190404\n", - " -1.2167063 ]\n", - " [ 2.2297084 4.8198485 4.2857304 ... 0.57483846 1.4093391\n", - " 0.0715822 ]\n", - " [ 1.679745 4.768068 5.416195 ... 0.17254728 0.4623217\n", - " 1.4772662 ]]]\n", - "\n", - "\n", - " [[[-2.0860114 -2.9508173 -1.4945896 ... -4.067145 -2.5652342\n", - " -3.5771027 ]\n", - " [-2.697845 -1.9273603 -2.3885014 ... -2.196533 -2.8573706\n", - " -2.0113711 ]\n", - " [-2.413383 -2.7204053 -1.0502659 ... -3.001385 -3.36447\n", - " -4.3225455 ]\n", - " ...\n", - " [ 1.2754489 0.9560999 1.5239805 ... -0.0105865 -1.00876\n", - " 2.6247358 ]\n", - " [ 1.1965859 1.0378222 1.1025598 ... -0.5394704 0.49838027\n", - " -0.9618193 ]\n", - " [ 1.1361816 1.3232857 0.687318 ... -0.23925456 -0.43679112\n", - " -0.79297894]]]]\n", - "param grad: conv.conv_in.conv.bias: shape: [32] stop_grad: False grad: [ 5.9604645e-07 -3.9339066e-06 -1.0728836e-06 -1.6689301e-06\n", - " 1.1920929e-06 -2.5033951e-06 -2.3841858e-07 4.7683716e-07\n", - " 4.2915344e-06 -1.9073486e-06 -1.9073486e-06 3.0994415e-06\n", - " -2.6822090e-06 3.3378601e-06 -4.2915344e-06 5.2452087e-06\n", - " 3.8146973e-06 2.3841858e-07 7.1525574e-07 -3.6954880e-06\n", - " 2.0563602e-06 -2.6226044e-06 3.0994415e-06 -3.5762787e-07\n", - " -4.7683716e-06 1.2218952e-06 3.3378601e-06 -2.5629997e-06\n", - " 2.3841858e-07 -1.7881393e-06 4.7683716e-07 -2.7418137e-06]\n", - "param grad: conv.conv_in.bn.weight: shape: [32] stop_grad: False grad: [ 2.363316 3.286464 1.9607866 -1.6367784 -1.6325372 -1.7729434\n", - " -0.9261875 2.0950415 0.1155543 -0.8857083 0.70079553 0.33920464\n", - " 2.6953902 -0.64524114 0.8845749 -1.2271115 0.6578167 -2.939814\n", - " 5.5728893 -1.0917969 0.01470797 1.395206 4.8009634 -0.744532\n", - " 0.944651 -1.092311 1.4877632 -3.042566 0.51686054 -5.4768667\n", - " -5.628145 -1.0894046 ]\n", - "param grad: conv.conv_in.bn.bias: shape: [32] stop_grad: False grad: [ 1.5193373 1.8838218 3.7722278 0.28052303 0.5386534 -0.44620085\n", - " -1.6977876 3.115642 0.03312349 -2.9121587 3.8925257 0.2288351\n", - " -2.273387 -1.3597974 4.3708124 -0.23374033 0.116272 -0.7064927\n", - " 6.5267463 -1.5318865 1.0288429 0.7928574 -0.24655592 -2.1116853\n", - " 2.922772 -3.3462617 1.7016437 -3.5471547 0.29777628 -3.2820854\n", - " -4.116946 -0.9909375 ]\n", - "param grad: conv.conv_in.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_in.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.conv.weight: shape: [32, 32, 21, 11] stop_grad: False grad: [[[[ 6.20494843e-01 5.95983505e-01 -1.48909020e+00 ... -6.86620831e-01\n", - " 6.71104014e-01 -1.95339048e+00]\n", - " [-3.91837955e-03 1.27062631e+00 -1.63248098e+00 ... 1.07290137e+00\n", - " -9.42245364e-01 -3.34277248e+00]\n", - " [ 2.41821265e+00 2.36212373e-01 -1.84433365e+00 ... 1.23182368e+00\n", - " 1.36039746e+00 -2.94621849e+00]\n", - " ...\n", - " [ 1.55153418e+00 7.25861669e-01 2.08785534e+00 ... -6.40172660e-01\n", - " -3.23889256e-02 -2.30832791e+00]\n", - " [ 3.69824195e+00 1.27163112e-01 4.09263194e-01 ... -8.60729575e-01\n", - " -3.51897454e+00 -2.10093403e+00]\n", - " [-4.94779050e-01 -3.74262631e-01 -1.19801068e+00 ... -2.05930543e+00\n", - " -7.38576293e-01 -9.44581270e-01]]\n", - "\n", - " [[-2.04341412e+00 -3.70606273e-01 -1.40429378e+00 ... -1.71711946e+00\n", - " -4.09437418e-01 -1.74107194e+00]\n", - " [-8.72247815e-01 -1.06301677e+00 -9.19306517e-01 ... -2.98976970e+00\n", - " -3.03250861e+00 -2.37099743e+00]\n", - " [-5.00457406e-01 -1.11882675e+00 -5.91526508e-01 ... 4.23921436e-01\n", - " -2.08650708e+00 -1.82109618e+00]\n", - " ...\n", - " [ 2.07773042e+00 1.40735030e-01 -2.60543615e-01 ... -1.55956164e-01\n", - " -1.31862307e+00 -2.07174897e+00]\n", - " [ 7.95007765e-01 1.14988625e-01 -1.43308258e+00 ... 8.29253554e-01\n", - " -9.57888126e-01 -3.82121086e-01]\n", - " [ 8.34397674e-02 1.38636863e+00 -1.21593380e+00 ... -2.65783578e-01\n", - " 1.78124309e-02 -3.40287232e+00]]\n", - "\n", - " [[ 6.27344131e-01 5.71699142e-02 -3.58010936e+00 ... -4.53077674e-01\n", - " 1.65331578e+00 2.58466601e-02]\n", - " [ 2.66681361e+00 2.02069378e+00 -1.52052927e+00 ... 2.94914508e+00\n", - " 1.94632411e+00 -1.06698799e+00]\n", - " [ 1.57839453e+00 -1.03649735e-01 -4.22528505e+00 ... 2.28863955e+00\n", - " 4.27859402e+00 3.66381669e+00]\n", - " ...\n", - " [-2.44603205e+00 -2.09621000e+00 -2.57623529e+00 ... 9.00211930e-01\n", - " 4.30536079e+00 -2.49779320e+00]\n", - " [-2.52187514e+00 -3.36546659e+00 -1.26748765e+00 ... 8.11533451e-01\n", - " 2.55930424e-01 4.50821817e-02]\n", - " [-3.40082574e+00 -3.26924801e+00 -5.86932135e+00 ... -1.18203712e+00\n", - " 1.09565187e+00 -4.96661961e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 8.20469666e+00 6.96195841e+00 2.73753977e+00 ... 8.34498823e-01\n", - " 2.56748104e+00 1.67592216e+00]\n", - " [ 9.85801792e+00 8.81465149e+00 6.09280396e+00 ... 1.42389655e+00\n", - " 2.92086434e+00 2.08308399e-01]\n", - " [ 8.00702763e+00 7.97301006e+00 4.64527416e+00 ... 8.61916900e-01\n", - " 3.55370259e+00 4.75085378e-01]\n", - " ...\n", - " [ 5.61662769e+00 -4.72857296e-01 -1.04519971e-01 ... -4.03000236e-01\n", - " -1.66419971e+00 -1.70375630e-01]\n", - " [ 4.52409792e+00 -3.70670676e-01 4.54190969e-02 ... -8.20453286e-01\n", - " 9.49141383e-02 8.88008535e-01]\n", - " [ 3.27219462e+00 8.93201411e-01 1.94810414e+00 ... -2.86915004e-02\n", - " 1.93200278e+00 8.19505215e-01]]\n", - "\n", - " [[ 5.84066296e+00 6.72855520e+00 5.21399307e+00 ... 4.55058670e+00\n", - " 3.19132543e+00 3.17435169e+00]\n", - " [ 6.04594421e+00 6.88997173e+00 5.00542831e+00 ... 2.23561144e+00\n", - " 2.76059532e+00 4.83479440e-01]\n", - " [ 5.36118126e+00 4.13896275e+00 3.68701124e+00 ... 3.64462805e+00\n", - " 2.80596399e+00 1.52781498e+00]\n", - " ...\n", - " [ 2.87856674e+00 5.84320784e-01 1.74297714e+00 ... 2.83938944e-01\n", - " -2.26546407e-01 -1.18434143e+00]\n", - " [ 2.08510804e+00 1.74915957e+00 1.58637917e+00 ... 6.41967297e-01\n", - " -1.31319761e-01 -3.85830402e-01]\n", - " [ 4.41666174e+00 2.58244562e+00 2.97712159e+00 ... 1.42317235e-01\n", - " 1.68037796e+00 -6.50003672e-01]]\n", - "\n", - " [[ 1.05511594e+00 6.74880028e-01 -7.64639139e-01 ... -2.15282440e-01\n", - " 2.07197094e+00 4.48752761e-01]\n", - " [ 2.12095881e+00 3.44118834e+00 1.61375272e+00 ... -1.18487728e+00\n", - " 1.88659012e+00 1.48252523e+00]\n", - " [ 8.33427787e-01 4.35035896e+00 -3.59877385e-02 ... 8.70242774e-01\n", - " 3.75945044e+00 -3.09408635e-01]\n", - " ...\n", - " [ 5.08510351e+00 4.73114061e+00 1.97346115e+00 ... -2.25924397e+00\n", - " -1.26373076e+00 -1.37826729e+00]\n", - " [ 6.17275095e+00 4.16016817e+00 3.15675950e+00 ... -2.02416754e+00\n", - " 1.50002241e-02 1.84633851e+00]\n", - " [ 7.32995272e+00 5.34601831e+00 4.58857203e+00 ... -1.88874304e+00\n", - " 1.53240371e+00 7.47349262e-02]]]\n", - "\n", - "\n", - " [[[-1.80918843e-01 -2.52616453e+00 -2.78145695e+00 ... 1.44283652e+00\n", - " -1.08945215e+00 4.19084758e-01]\n", - " [-9.66833949e-01 -2.41106153e+00 -3.48886085e+00 ... -1.87193304e-01\n", - " 8.21905077e-01 1.89097953e+00]\n", - " [-1.59118319e+00 -2.56997013e+00 -3.10426521e+00 ... 2.05900550e+00\n", - " -2.78253704e-01 6.96343541e-01]\n", - " ...\n", - " [ 6.66302443e-02 -2.00887346e+00 -3.17550874e+00 ... 7.97579706e-01\n", - " -9.71581042e-02 1.71877682e+00]\n", - " [-8.01679730e-01 -2.02678037e+00 -3.21915555e+00 ... 8.35528374e-01\n", - " -1.15296638e+00 4.35728967e-01]\n", - " [ 1.45292446e-01 -2.15479851e+00 -1.51839817e+00 ... -3.07936192e-01\n", - " -5.39051890e-01 1.13107657e+00]]\n", - "\n", - " [[-2.43341160e+00 -3.35346818e+00 -9.87014294e-01 ... 1.34049034e+00\n", - " 2.95773447e-02 1.27177119e+00]\n", - " [-2.61602497e+00 -9.76761580e-01 -2.52060473e-01 ... -1.38134825e+00\n", - " 3.85564029e-01 4.57195908e-01]\n", - " [-2.23676014e+00 -4.00404739e+00 -2.23409963e+00 ... -1.41846514e+00\n", - " -6.58698231e-02 -3.61778140e-01]\n", - " ...\n", - " [-1.13604403e+00 -6.03917837e-02 -4.95491922e-01 ... 2.14673686e+00\n", - " 1.21484184e+00 2.22764325e+00]\n", - " [-1.05162430e+00 -1.59828448e+00 3.15489501e-01 ... 2.28046751e+00\n", - " 2.39702511e+00 2.43942714e+00]\n", - " [-1.27370405e+00 -2.05736399e-01 -1.12124372e+00 ... 2.21597219e+00\n", - " 2.50086927e+00 1.91134131e+00]]\n", - "\n", - " [[-4.53170598e-01 -1.59644139e+00 -3.63470483e+00 ... -4.35066032e+00\n", - " -3.79540777e+00 -1.09796596e+00]\n", - " [-2.21036464e-01 -2.53353834e+00 -1.28269875e+00 ... -3.38615727e+00\n", - " -2.59143281e+00 7.74220943e-01]\n", - " [-6.89323783e-01 -1.44375205e+00 6.66438341e-02 ... -1.30736077e+00\n", - " -1.23293114e+00 1.58148706e+00]\n", - " ...\n", - " [ 1.63751483e+00 -4.08427984e-01 -8.15176964e-01 ... 3.70807743e+00\n", - " 2.04232907e+00 1.97716308e+00]\n", - " [ 2.13261342e+00 1.85947633e+00 -8.06532025e-01 ... 1.98311245e+00\n", - " 2.27003932e+00 -1.11734614e-01]\n", - " [ 1.28702402e+00 3.98628891e-01 -1.63712263e+00 ... 8.00528765e-01\n", - " 5.78273535e-01 -2.59924948e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 3.96233416e+00 4.66794682e+00 1.39437711e+00 ... 7.52061129e-01\n", - " -1.53534544e+00 -6.67162359e-01]\n", - " [ 2.33841681e+00 3.35811281e+00 9.80114818e-01 ... 1.48806703e+00\n", - " 2.68609226e-01 -1.35124445e+00]\n", - " [ 2.08177710e+00 4.28519583e+00 1.52450514e+00 ... 7.45321214e-01\n", - " -5.04359961e-01 -1.81241560e+00]\n", - " ...\n", - " [ 2.95398951e-01 4.30877179e-01 -2.03731894e+00 ... -4.20221925e-01\n", - " 3.29260826e-01 5.83679557e-01]\n", - " [ 1.30742240e+00 -6.32183790e-01 -3.13741422e+00 ... 9.63868052e-02\n", - " 2.91730791e-01 1.33400351e-01]\n", - " [ 5.43292165e-01 -2.83665359e-01 -1.88138187e+00 ... 2.15468198e-01\n", - " 4.90157723e-01 2.40562439e+00]]\n", - "\n", - " [[ 1.57632053e+00 6.27885723e+00 2.87853765e+00 ... 3.07016110e+00\n", - " 1.91490650e+00 1.76274943e+00]\n", - " [ 2.57776356e+00 4.07256317e+00 2.52231169e+00 ... 4.09494352e+00\n", - " 2.53548074e+00 2.44395185e+00]\n", - " [ 2.43037057e+00 4.35728836e+00 1.96233964e+00 ... 2.26702976e+00\n", - " 2.94634581e+00 2.21452284e+00]\n", - " ...\n", - " [-2.72509992e-01 -8.41220498e-01 -1.89133918e+00 ... -1.80079627e+00\n", - " -2.00367713e+00 -7.09145784e-01]\n", - " [ 8.21575999e-01 -1.13323164e+00 -2.62418866e+00 ... -2.38889670e+00\n", - " -7.83945560e-01 -1.01922750e-01]\n", - " [-1.14730227e+00 -1.42182577e+00 -2.00993991e+00 ... -2.11025667e+00\n", - " 1.60286129e-02 -7.26446986e-01]]\n", - "\n", - " [[ 4.20389509e+00 3.75917768e+00 4.97653627e+00 ... 1.23642838e+00\n", - " 8.52760911e-01 1.27920091e-01]\n", - " [ 5.29409122e+00 5.29002380e+00 3.96404648e+00 ... 1.91227329e+00\n", - " 3.97556186e-01 1.69182217e+00]\n", - " [ 4.60112572e+00 4.12772799e+00 2.10280085e+00 ... 3.24303842e+00\n", - " -1.07720590e+00 -3.81854475e-01]\n", - " ...\n", - " [ 1.81884170e-02 -3.11472058e+00 -8.23525012e-01 ... -2.40161085e+00\n", - " -4.48192549e+00 -6.14600539e-01]\n", - " [ 1.16305006e+00 -1.15409636e+00 -3.48765063e+00 ... -1.97504926e+00\n", - " -4.44984436e+00 -2.28429958e-01]\n", - " [ 1.29197860e+00 6.17720246e-01 -5.87171853e-01 ... -1.35258228e-01\n", - " -1.29259872e+00 1.30360842e-01]]]\n", - "\n", - "\n", - " [[[-1.26687372e+00 -2.33633637e+00 -1.49625254e+00 ... 2.52396107e+00\n", - " -6.68072224e-01 -1.13282454e+00]\n", - " [-1.34229445e+00 -2.87080932e+00 -2.57388353e+00 ... -8.75385761e-01\n", - " -1.00205469e+00 -3.58956242e+00]\n", - " [-9.49853599e-01 -5.78684711e+00 -3.52962446e+00 ... 8.88233304e-01\n", - " 2.25133196e-01 -1.02802217e+00]\n", - " ...\n", - " [-7.38113701e-01 -3.47510982e+00 -3.23011065e+00 ... -1.25624001e+00\n", - " -1.63268471e+00 6.00247443e-01]\n", - " [-2.29733467e+00 -5.72547615e-01 -1.98301303e+00 ... -1.90137398e+00\n", - " -1.47013855e+00 -1.45779204e+00]\n", - " [-2.24628520e+00 -3.36337948e+00 -3.91878939e+00 ... -1.53652275e+00\n", - " -1.36285520e+00 -1.68160331e+00]]\n", - "\n", - " [[-8.11348319e-01 -7.17824280e-01 -1.02243233e+00 ... -2.69050407e+00\n", - " -2.32403350e+00 -4.25943947e+00]\n", - " [-2.35056520e+00 -2.35941172e+00 -1.24398732e+00 ... -2.08313870e+00\n", - " -1.16508257e+00 -1.30353463e+00]\n", - " [-2.25146723e+00 -1.94972813e+00 -1.13295293e+00 ... -2.61496377e+00\n", - " -1.91106403e+00 -1.07801402e+00]\n", - " ...\n", - " [-2.67012739e+00 -3.20916414e+00 -2.41768575e+00 ... 2.65138328e-01\n", - " -5.27612507e-01 1.44604075e+00]\n", - " [-3.54237866e+00 -3.62832785e+00 -2.40270257e+00 ... -9.76106226e-02\n", - " 4.67946082e-01 -7.24248111e-01]\n", - " [-2.49844384e+00 -3.42463255e+00 -2.99040008e+00 ... 4.28889185e-01\n", - " -7.51657963e-01 -1.00530767e+00]]\n", - "\n", - " [[-8.42589438e-02 1.42022014e-01 -8.51281703e-01 ... 4.21745628e-01\n", - " -2.35717297e-02 -1.71374834e+00]\n", - " [-1.05496287e+00 3.82416457e-01 -4.40595537e-01 ... 1.03381336e-01\n", - " -1.41204190e+00 -7.58325040e-01]\n", - " [-2.28930283e+00 -2.03857040e+00 -9.16261196e-01 ... -3.94939929e-01\n", - " -1.07798588e+00 -1.48433352e+00]\n", - " ...\n", - " [-3.11473966e-01 -1.40877593e+00 -2.42908645e+00 ... 7.88682699e-01\n", - " 1.24199319e+00 1.89949930e-01]\n", - " [ 5.44084549e-01 -1.02425671e+00 -1.53991556e+00 ... -4.36764538e-01\n", - " -5.78772545e-01 2.62665659e-01]\n", - " [ 1.26812792e+00 -9.89493608e-01 -1.47972977e+00 ... 2.21440494e-02\n", - " 2.79776216e-01 7.63269484e-01]]\n", - "\n", - " ...\n", - "\n", - " [[ 6.02095068e-01 5.93243122e-01 -1.06838238e+00 ... 3.56546330e+00\n", - " 1.16390383e+00 -1.47593319e-01]\n", - " [ 1.80458140e+00 1.68401957e+00 4.17516947e-01 ... 3.33444500e+00\n", - " 1.89411759e+00 1.03220642e-01]\n", - " [ 2.74264169e+00 2.92038846e+00 1.00775683e+00 ... 3.53285050e+00\n", - " 2.07282662e+00 -2.56800652e-01]\n", - " ...\n", - " [ 4.88933468e+00 3.72433925e+00 3.58677816e+00 ... 1.98363388e+00\n", - " 1.80851030e+00 8.32634747e-01]\n", - " [ 4.01546288e+00 4.78934765e+00 2.94778132e+00 ... 2.99637699e+00\n", - " 1.30439472e+00 3.61029744e-01]\n", - " [ 3.13628030e+00 2.01894832e+00 2.82585931e+00 ... 2.54264188e+00\n", - " -9.16651785e-02 9.93353873e-02]]\n", - "\n", - " [[ 2.35585642e+00 8.42678428e-01 1.57331872e+00 ... 3.65935063e+00\n", - " 3.94066262e+00 4.89832020e+00]\n", - " [ 1.85791731e+00 1.34373701e+00 1.30812299e+00 ... 2.71434736e+00\n", - " 3.22004294e+00 2.99872303e+00]\n", - " [ 1.67675853e+00 -4.05569375e-02 1.85539150e+00 ... 3.73934364e+00\n", - " 2.98195982e+00 3.37315011e+00]\n", - " ...\n", - " [ 2.14539170e+00 2.86586595e+00 2.20222116e+00 ... 1.20492995e+00\n", - " 2.13971066e+00 1.94932449e+00]\n", - " [ 4.68422651e+00 3.80044746e+00 4.23209000e+00 ... 2.40658951e+00\n", - " 2.29117441e+00 2.52368808e+00]\n", - " [ 3.10694575e+00 2.49402595e+00 4.53786707e+00 ... 9.08902645e-01\n", - " 1.86903965e+00 2.27776885e+00]]\n", - "\n", - " [[ 1.45200038e+00 5.17961740e-01 -1.58403587e+00 ... 5.07019472e+00\n", - " 7.87163258e-01 1.20610237e+00]\n", - " [ 3.39321136e+00 2.21043849e+00 -6.31202877e-01 ... 4.97822762e+00\n", - " 9.66498017e-01 1.18883348e+00]\n", - " [ 1.20627856e+00 1.82759428e+00 5.91053367e-01 ... 4.14318657e+00\n", - " 5.25399208e-01 -1.16850233e+00]\n", - " ...\n", - " [ 1.05183899e+00 5.80030501e-01 1.89724147e+00 ... 2.54626465e+00\n", - " -1.49128008e+00 -1.85064209e+00]\n", - " [ 1.50983357e+00 2.85973406e+00 2.61224055e+00 ... 4.83481932e+00\n", - " 9.67048705e-02 -4.37043965e-01]\n", - " [ 2.57720876e+00 2.09961963e+00 4.11754288e-02 ... 3.80421424e+00\n", - " -7.83308804e-01 -1.64871216e+00]]]\n", - "\n", - "\n", - " ...\n", - "\n", - "\n", - " [[[-1.16345096e+00 -2.53971386e+00 -8.99101734e-01 ... -4.35583591e-01\n", - " -1.29671764e+00 -1.61429560e+00]\n", - " [ 3.72841507e-01 3.45808208e-01 -1.82167351e+00 ... -2.14515448e+00\n", - " -1.26383066e+00 -2.27464601e-01]\n", - " [ 1.58568513e+00 2.58181524e+00 1.86554670e+00 ... -1.10401320e+00\n", - " -3.68550658e-01 -2.58849680e-01]\n", - " ...\n", - " [-9.15827155e-01 -1.25424683e+00 -4.04716206e+00 ... 2.13138080e+00\n", - " 2.67662477e+00 2.31014514e+00]\n", - " [-3.19453120e-01 -6.71132684e-01 -1.51378751e+00 ... 1.86080432e+00\n", - " 2.77418542e+00 1.22875953e+00]\n", - " [-1.20453942e+00 -3.93669218e-01 -1.51751983e+00 ... 1.17620552e+00\n", - " 1.95602298e+00 7.64306366e-01]]\n", - "\n", - " [[-8.73186827e-01 -2.12537169e+00 -1.91664994e+00 ... -2.90821463e-01\n", - " 1.90896463e+00 8.02283168e-01]\n", - " [-1.06389821e+00 -2.15300727e+00 -1.82113051e+00 ... -4.34280694e-01\n", - " 1.53455496e+00 1.94702053e+00]\n", - " [-2.08403468e+00 -4.72900331e-01 -1.10610819e+00 ... -8.79420400e-01\n", - " 7.79394627e-01 2.02670670e+00]\n", - " ...\n", - " [-4.28208113e-01 -7.90894389e-01 -1.06713009e+00 ... 1.12579381e+00\n", - " 9.61961091e-01 1.40342009e+00]\n", - " [ 4.40416574e-01 -1.65901780e-02 -1.05338669e+00 ... 1.40698349e+00\n", - " 9.43485856e-01 2.34856772e+00]\n", - " [-1.20572495e+00 -2.03134632e+00 4.88817632e-01 ... 2.20770907e+00\n", - " 1.38143206e+00 2.00714707e+00]]\n", - "\n", - " [[ 9.00486887e-01 -9.50459957e-01 -1.42935121e+00 ... -1.30648065e+00\n", - " -2.52133775e+00 -8.87715697e-01]\n", - " [ 3.73431134e+00 1.69571114e+00 5.99429727e-01 ... 6.64332986e-01\n", - " -6.10453069e-01 2.06534386e+00]\n", - " [ 1.59800696e+00 -4.59622175e-01 -6.73136234e-01 ... 2.18770742e-01\n", - " -1.12928271e+00 4.87097502e-02]\n", - " ...\n", - " [ 1.92336845e+00 1.37130380e-01 -3.51048648e-01 ... 5.41638851e-01\n", - " 1.06069386e+00 1.36404145e+00]\n", - " [ 1.29641414e+00 -2.79530913e-01 -2.63607264e-01 ... -8.62445176e-01\n", - " 1.48393130e+00 2.69196725e+00]\n", - " [ 1.14442182e+00 -1.24098969e+00 3.70959163e-01 ... -1.12241995e+00\n", - " 3.67927134e-01 2.55976987e+00]]\n", - "\n", - " ...\n", - "\n", - " [[ 5.32017851e+00 3.64207411e+00 3.84571218e+00 ... 3.60754800e+00\n", - " 2.57500267e+00 -1.38083458e-01]\n", - " [ 5.69058084e+00 3.93056583e+00 2.93337941e+00 ... 3.17091584e+00\n", - " 2.34770632e+00 6.48133337e-01]\n", - " [ 5.98239613e+00 6.16548634e+00 3.04750896e+00 ... 5.51510525e+00\n", - " 4.34810448e+00 1.31588542e+00]\n", - " ...\n", - " [ 5.09930992e+00 3.32360983e+00 2.29228449e+00 ... 3.45123887e-01\n", - " 1.06280947e+00 -5.93325794e-02]\n", - " [ 4.19760656e+00 3.97779059e+00 1.66905916e+00 ... 3.68937254e-01\n", - " 8.06131065e-02 8.08142900e-01]\n", - " [ 4.52498960e+00 3.45109749e+00 1.01074433e+00 ... -2.54036248e-01\n", - " 3.13675582e-01 2.13851762e+00]]\n", - "\n", - " [[ 6.93927193e+00 6.05758238e+00 4.60648441e+00 ... 4.32221603e+00\n", - " 3.17874146e+00 1.47012353e+00]\n", - " [ 7.88523865e+00 6.62228966e+00 4.77496338e+00 ... 4.45868683e+00\n", - " 2.73698759e+00 2.17057824e+00]\n", - " [ 7.12061214e+00 6.01714134e+00 4.52996492e+00 ... 3.97184372e+00\n", - " 3.43153954e+00 1.21802723e+00]\n", - " ...\n", - " [ 2.85720730e+00 1.89639473e+00 1.96340394e+00 ... 1.89643729e+00\n", - " 1.64856291e+00 1.15853786e+00]\n", - " [ 3.88248491e+00 2.16386199e+00 1.53069091e+00 ... 2.71704245e+00\n", - " 2.24890351e+00 2.22156644e+00]\n", - " [ 5.27136230e+00 1.68400204e+00 2.09500480e+00 ... 2.75956345e+00\n", - " 3.71970820e+00 1.69852686e+00]]\n", - "\n", - " [[ 2.55598164e+00 1.64588141e+00 6.70431674e-01 ... 3.24091220e+00\n", - " 1.48759770e+00 -1.72001183e+00]\n", - " [ 4.33942318e+00 8.40826690e-01 -7.40000725e-01 ... 7.24577069e-01\n", - " 1.74327165e-01 -1.83029580e+00]\n", - " [ 4.39864540e+00 2.28395438e+00 -1.90353513e-01 ... 5.58019161e+00\n", - " 1.05627227e+00 -8.02519619e-01]\n", - " ...\n", - " [ 1.97654784e+00 3.26888156e+00 1.52879453e+00 ... 3.15013933e+00\n", - " 4.66731453e+00 4.98701715e+00]\n", - " [ 1.40016854e+00 3.45761251e+00 3.68359756e+00 ... 1.14207900e+00\n", - " 3.32219076e+00 3.83035636e+00]\n", - " [ 1.99269783e+00 2.15428829e+00 3.35396528e-01 ... 2.45916694e-01\n", - " 2.13785577e+00 4.33214951e+00]]]\n", - "\n", - "\n", - " [[[ 1.35320330e+00 5.05850911e-02 1.04915988e+00 ... 1.82023585e-01\n", - " 2.72914767e-01 3.92112255e-01]\n", - " [ 1.04646444e+00 7.60913491e-01 1.93323612e+00 ... 1.19493449e+00\n", - " -1.44200325e-01 4.07531261e-02]\n", - " [-9.88207340e-01 -1.46165287e+00 1.05884135e-01 ... -3.23057353e-01\n", - " -2.28934169e+00 -7.38609374e-01]\n", - " ...\n", - " [ 1.01198792e+00 2.34331083e+00 1.04566610e+00 ... 1.29697472e-01\n", - " -1.23878837e+00 2.21006930e-01]\n", - " [-3.75360101e-01 1.53673506e+00 -1.32206869e+00 ... -2.55255580e-01\n", - " -6.22699618e-01 -1.73162484e+00]\n", - " [ 4.34735864e-01 5.08327007e-01 -3.49233925e-01 ... -1.04749084e+00\n", - " -1.15777385e+00 -1.13671994e+00]]\n", - "\n", - " [[ 1.67839336e+00 -1.80224836e-01 1.02194118e+00 ... 8.44027162e-01\n", - " 8.81283879e-02 -1.37762165e+00]\n", - " [ 8.39694083e-01 1.32322550e+00 4.02442753e-01 ... -4.21785116e-01\n", - " -9.98012185e-01 -1.11348581e+00]\n", - " [ 7.64424682e-01 8.58965695e-01 2.94626594e-01 ... -6.65519595e-01\n", - " -3.65677416e-01 -2.25250268e+00]\n", - " ...\n", - " [-1.10193872e+00 1.18070498e-01 1.04604781e-01 ... -1.44486964e+00\n", - " -2.52748466e+00 -2.16131711e+00]\n", - " [-1.06079710e+00 -1.48379254e+00 3.80138367e-01 ... -1.62288392e+00\n", - " -2.44736362e+00 -8.78590107e-01]\n", - " [ 3.44401300e-02 -2.60935068e+00 -2.35597759e-01 ... -2.41114974e+00\n", - " -2.45255780e+00 -1.82384634e+00]]\n", - "\n", - " [[ 1.37670958e+00 1.58661580e+00 -2.85664916e-01 ... 1.49081087e+00\n", - " 4.13422853e-01 1.12761199e+00]\n", - " [ 1.54148173e+00 6.22704089e-01 1.41886568e+00 ... 1.59678531e+00\n", - " -8.72656107e-01 1.52415514e-01]\n", - " [ 3.30207205e+00 2.89925170e+00 1.91855145e+00 ... 3.18863559e+00\n", - " 1.87347198e+00 9.48901057e-01]\n", - " ...\n", - " [-1.53920484e+00 1.77375078e-02 -1.02018684e-01 ... 1.94011092e+00\n", - " -6.83587790e-01 1.49154460e+00]\n", - " [-2.27719522e+00 1.02481163e+00 -2.11300224e-01 ... -8.18020821e-01\n", - " 1.54248989e+00 -1.46732473e+00]\n", - " [-4.50206220e-01 3.62383485e+00 1.07175660e+00 ... 4.25961137e-01\n", - " 1.12405360e-01 -6.87821358e-02]]\n", - "\n", - " ...\n", - "\n", - " [[-3.40477467e-01 -2.99311423e+00 -2.12096786e+00 ... 2.27393007e+00\n", - " 4.03424358e+00 3.73335361e+00]\n", - " [-6.99971199e-01 -2.97719741e+00 -2.72910309e+00 ... 1.50101089e+00\n", - " 2.29408574e+00 3.14105940e+00]\n", - " [-1.41648722e+00 -1.86292887e+00 -1.84006739e+00 ... 2.78402638e+00\n", - " 3.91481900e+00 5.32456112e+00]\n", - " ...\n", - " [ 5.97958088e-01 1.50512588e+00 6.23718500e-01 ... 2.83813477e+00\n", - " 3.87909842e+00 3.33359623e+00]\n", - " [ 1.65542316e+00 3.56163192e+00 4.01527691e+00 ... 3.38367462e+00\n", - " 1.55827272e+00 2.50741863e+00]\n", - " [ 2.82036042e+00 2.53322673e+00 4.38798475e+00 ... 4.64642382e+00\n", - " 3.28739667e+00 3.02895570e+00]]\n", - "\n", - " [[-3.47941303e+00 -3.49006844e+00 -2.25583363e+00 ... 1.45181656e-01\n", - " 1.52944064e+00 2.08810711e+00]\n", - " [-2.27786446e+00 -4.59218550e+00 -2.74722624e+00 ... -1.73136210e+00\n", - " 7.46028006e-01 1.74789345e+00]\n", - " [-3.35524082e+00 -4.58244705e+00 -2.40820456e+00 ... -5.04051924e-01\n", - " 1.49640536e+00 2.16613841e+00]\n", - " ...\n", - " [ 5.26107132e-01 2.05329061e+00 2.84252572e+00 ... 1.33222675e+00\n", - " 3.87935114e+00 3.69385266e+00]\n", - " [ 4.38092083e-01 2.15028906e+00 3.13363624e+00 ... 3.36048746e+00\n", - " 5.36551809e+00 2.94915986e+00]\n", - " [ 2.75497317e+00 3.25929213e+00 2.33522987e+00 ... 1.69926262e+00\n", - " 3.93462896e+00 3.68200874e+00]]\n", - "\n", - " [[ 1.10951948e+00 5.31419516e-02 -1.58864903e+00 ... 5.24887085e+00\n", - " 1.60273385e+00 4.90113163e+00]\n", - " [-2.94517064e+00 -2.81092644e+00 -4.89631557e+00 ... 3.99868512e+00\n", - " 1.40544355e+00 2.84833241e+00]\n", - " [-3.51893663e-01 -3.53325534e+00 -2.21239805e+00 ... 4.26225853e+00\n", - " 6.87886119e-01 2.58609629e+00]\n", - " ...\n", - " [ 2.92248201e+00 5.40264511e+00 4.65721560e+00 ... 5.24537373e+00\n", - " 2.30406880e+00 1.29892707e+00]\n", - " [ 1.43473256e+00 4.61167526e+00 3.57578802e+00 ... 5.12181854e+00\n", - " 8.59923482e-01 1.38731599e+00]\n", - " [-6.50881350e-01 2.18233657e+00 2.74669623e+00 ... 4.86368895e+00\n", - " 1.44120216e+00 1.79993320e+00]]]\n", - "\n", - "\n", - " [[[ 1.64106202e+00 3.54410499e-01 -3.54172409e-01 ... 2.32646990e+00\n", - " 1.65043330e+00 3.45897645e-01]\n", - " [ 2.16236949e+00 1.28213906e+00 2.26082468e+00 ... 6.10507369e-01\n", - " 9.12241280e-01 1.27429694e-01]\n", - " [ 2.07962990e+00 7.03816175e-01 2.01272345e+00 ... -2.26959705e-01\n", - " 1.00041127e+00 5.87104559e-02]\n", - " ...\n", - " [-1.62972426e+00 -3.04028845e+00 -1.39124167e+00 ... 2.47561097e+00\n", - " 2.35047388e+00 1.61532843e+00]\n", - " [-1.97368932e+00 -5.44541061e-01 -5.92882216e-01 ... 1.39800012e+00\n", - " 2.32770801e+00 9.96662021e-01]\n", - " [-1.15636075e+00 -1.34654212e+00 -8.50648999e-01 ... 1.85655832e+00\n", - " 2.05776072e+00 5.34575820e-01]]\n", - "\n", - " [[-1.02104437e+00 3.08469892e-01 2.81789303e-01 ... -8.24654043e-01\n", - " -9.85817850e-01 -2.05517030e+00]\n", - " [ 9.50192690e-01 3.35105330e-01 5.31637192e-01 ... -1.42974198e-01\n", - " -1.79659498e+00 -1.58266973e+00]\n", - " [-2.51316994e-01 -1.28709340e+00 3.01498562e-01 ... -1.32253516e+00\n", - " -1.55507576e+00 -9.37123299e-01]\n", - " ...\n", - " [ 2.33016998e-01 2.92454743e+00 3.15420461e+00 ... 1.15574491e+00\n", - " 1.27850962e+00 1.35487700e+00]\n", - " [ 3.81013602e-01 1.44239831e+00 6.64825320e-01 ... -3.89374971e-01\n", - " 1.50716826e-01 1.33641326e+00]\n", - " [ 1.71373415e+00 1.67357373e+00 1.76596940e+00 ... 1.57941079e+00\n", - " 1.60940981e+00 1.78091609e+00]]\n", - "\n", - " [[-5.16522598e+00 -1.68099070e+00 -3.24440050e+00 ... -3.46229005e+00\n", - " -2.18273020e+00 -1.98621082e+00]\n", - " [-3.05743694e+00 9.15392339e-01 -1.93508530e+00 ... -1.82306373e+00\n", - " -2.12960863e+00 -3.45255351e+00]\n", - " [-4.32777822e-01 -1.00303245e+00 -1.61397791e+00 ... -2.08376765e+00\n", - " -3.72989595e-01 -1.36516929e+00]\n", - " ...\n", - " [-5.83641946e-01 4.14125490e+00 1.58227599e+00 ... 2.03144050e+00\n", - " 2.13982654e+00 -1.81909311e+00]\n", - " [-1.74230576e+00 2.39347410e+00 2.44080925e+00 ... 5.43732524e-01\n", - " 2.07899213e+00 -3.71748984e-01]\n", - " [ 3.80016506e-01 7.84988403e-01 1.20596504e+00 ... -2.32057095e+00\n", - " -2.81265080e-01 -3.69353056e+00]]\n", - "\n", - " ...\n", - "\n", - " [[-3.48024845e+00 -2.60937548e+00 -3.84952760e+00 ... 6.68736577e-01\n", - " -1.75104141e-02 -3.54720926e+00]\n", - " [-2.59637117e+00 -5.18190145e+00 -2.33887696e+00 ... 9.13373232e-02\n", - " -3.58282638e+00 -2.40778995e+00]\n", - " [-2.50912881e+00 -1.22113395e+00 -2.34372020e+00 ... 1.40071487e+00\n", - " -1.67449510e+00 -1.14655948e+00]\n", - " ...\n", - " [-5.75253534e+00 -6.67348385e+00 -5.05184650e+00 ... -2.73145151e+00\n", - " -1.48933101e+00 -1.36807609e+00]\n", - " [-3.29049587e+00 -3.73956156e+00 -2.85064268e+00 ... -3.92481357e-01\n", - " -8.00529659e-01 -8.39800835e-01]\n", - " [-4.30351114e+00 -4.21471930e+00 -2.41703367e+00 ... -1.27081513e+00\n", - " 1.67839837e+00 8.47821474e-01]]\n", - "\n", - " [[-5.27856112e-01 -1.09752083e+00 3.39107156e-01 ... 2.00062895e+00\n", - " 8.83528054e-01 2.57416844e-01]\n", - " [-1.58655810e+00 -3.36268663e-01 1.16161990e+00 ... 1.54868484e+00\n", - " 2.38878536e+00 1.84097290e+00]\n", - " [ 5.96052647e-01 2.15484858e-01 1.85280466e+00 ... 2.74587560e+00\n", - " 1.61432290e+00 1.13214278e+00]\n", - " ...\n", - " [-4.57659864e+00 -5.42679739e+00 -4.35204458e+00 ... -1.82452416e+00\n", - " -2.18670201e+00 -3.91811800e+00]\n", - " [-1.32477629e+00 -4.19110394e+00 -3.41308069e+00 ... 1.39622003e-01\n", - " -1.59393203e+00 -9.08105671e-01]\n", - " [-3.60161018e+00 -4.05932713e+00 -2.23674798e+00 ... 9.09647286e-01\n", - " 9.73127842e-01 1.19991803e+00]]\n", - "\n", - " [[ 2.04062796e+00 7.95603275e-01 -1.28833270e+00 ... 4.64749050e+00\n", - " 2.25974560e+00 1.02396965e+00]\n", - " [ 1.68882537e+00 2.63353348e+00 2.53597498e-02 ... 4.69063854e+00\n", - " -4.19382691e-01 2.91669458e-01]\n", - " [ 7.71395087e-01 1.20833695e+00 -2.58601785e-01 ... 1.21794045e+00\n", - " -1.51922226e-01 7.44265199e-01]\n", - " ...\n", - " [-6.66095781e+00 -4.81577682e+00 -5.39921665e+00 ... -2.20548606e+00\n", - " 5.72486281e-01 -4.35207397e-01]\n", - " [-7.51608658e+00 -6.67776871e+00 -3.73199415e+00 ... -1.70327055e+00\n", - " 1.01334639e-02 -3.20627165e+00]\n", - " [-5.73050356e+00 -2.74379373e+00 -3.70248461e+00 ... -1.09794116e+00\n", - " -1.73590891e-02 -1.80156028e+00]]]]\n", - "param grad: conv.conv_stack.0.conv.bias: shape: [32] stop_grad: False grad: [-1.4305115e-06 0.0000000e+00 -4.0531158e-06 -1.6689301e-06\n", - " 2.3841858e-07 -7.1525574e-07 1.1920929e-06 1.5497208e-06\n", - " -2.3841858e-07 1.6689301e-06 9.5367432e-07 9.5367432e-07\n", - " -2.6226044e-06 1.1920929e-06 1.3113022e-06 1.9669533e-06\n", - " -4.7683716e-07 1.1920929e-06 -1.6689301e-06 -1.5497208e-06\n", - " -2.2649765e-06 4.7683716e-07 2.3841858e-06 -3.5762787e-06\n", - " 2.3841858e-07 2.1457672e-06 -3.5762787e-07 8.3446503e-07\n", - " -3.5762787e-07 -7.1525574e-07 2.6524067e-06 -1.1920929e-06]\n", - "param grad: conv.conv_stack.0.bn.weight: shape: [32] stop_grad: False grad: [-3.7669735 1.5226867 1.759756 4.501629 -2.2077336 0.18411277\n", - " 1.3558264 -1.0269645 3.9628277 3.9300344 -2.80754 1.8462183\n", - " -0.03385968 2.1284049 0.46124816 -4.364863 0.78491163 0.25565645\n", - " -5.3538237 3.2606194 0.79100513 -1.4652673 2.769378 1.2283417\n", - " -4.7466464 -1.3404545 -6.9374166 0.710248 2.0944448 0.4334769\n", - " -0.24313992 0.31392363]\n", - "param grad: conv.conv_stack.0.bn.bias: shape: [32] stop_grad: False grad: [-0.6251638 2.833331 0.6993131 3.7106915 -2.262496 0.7390424\n", - " 0.5360477 -2.803875 2.1646228 2.117193 -1.9988279 1.5135905\n", - " -2.0181084 2.6450465 0.06302822 -3.0530102 1.4788482 0.5941844\n", - " -3.1690063 1.8753575 -0.0737313 -2.7806277 -0.04483938 0.16129279\n", - " -1.2960215 -0.38020235 -0.55218065 0.10754502 2.065371 -1.4703183\n", - " -0.40964937 -1.4454535 ]\n", - "param grad: conv.conv_stack.0.bn._mean: shape: [32] stop_grad: True grad: None\n", - "param grad: conv.conv_stack.0.bn._variance: shape: [32] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_fc.weight: shape: [1312, 1024] stop_grad: False grad: [[-0.46178514 0.1095643 0.06441769 ... 0.42020613 -0.34181893\n", - " -0.0658682 ]\n", - " [-0.03619978 0.21653323 0.01727325 ... 0.05731536 -0.37822944\n", - " -0.05464617]\n", - " [-0.32397318 0.04158126 -0.08091418 ... 0.0928297 -0.06518176\n", - " -0.40110156]\n", - " ...\n", - " [-0.2702023 0.05126935 0.11825457 ... 0.0069707 -0.36951366\n", - " 0.37071258]\n", - " [-0.11326203 0.19305304 -0.133317 ... -0.13030824 -0.09068564\n", - " 0.32735693]\n", - " [-0.04543798 0.09902512 -0.10745425 ... -0.06685166 -0.3055201\n", - " 0.0752247 ]]\n", - "param grad: rnn.rnn_stacks.0.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.07338604 0.64991236 0.5465856 ... 0.507725 0.14061031\n", - " 0.3020359 ]\n", - "param grad: rnn.rnn_stacks.0.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.41395143 -0.28493872 0.36796764 ... 0.2387953 0.06732331\n", - " 0.16263628]\n", - "param grad: rnn.rnn_stacks.0.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.0.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.09370177 -0.12264141 -0.08237482 ... -0.50241685 -0.149155\n", - " -0.25661892]\n", - " [-0.37426725 0.44987115 0.10685667 ... -0.65946174 -0.4499248\n", - " -0.17545304]\n", - " [-0.03753807 0.33422717 0.12750985 ... 0.05405155 -0.17648363\n", - " 0.05315325]\n", - " ...\n", - " [ 0.15721183 0.03064088 -0.00751081 ... 0.27183983 0.3881693\n", - " -0.01544908]\n", - " [ 0.26047793 0.16917065 0.00915196 ... 0.18076143 -0.05080506\n", - " 0.14791614]\n", - " [ 0.19052255 0.03642382 -0.14313167 ... 0.2611448 0.20763844\n", - " 0.26846847]]\n", - "param grad: rnn.rnn_stacks.0.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.4139514 -0.28493875 0.36796758 ... 0.23879525 0.06732336\n", - " 0.16263627]\n", - "param grad: rnn.rnn_stacks.0.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.0.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[ 0.04214853 -0.1710323 0.17557406 ... 0.11926915 0.21577051\n", - " -0.30598596]\n", - " [-0.02370887 -0.03498494 -0.05991999 ... -0.06049232 -0.14527473\n", - " -0.5335691 ]\n", - " [-0.21417995 -0.10263194 -0.05903128 ... -0.26958284 0.05936668\n", - " 0.25522667]\n", - " ...\n", - " [ 0.31594425 -0.29487017 0.15871571 ... 0.3504135 -0.1418606\n", - " -0.07482046]\n", - " [ 0.22316164 0.7682122 -0.22191924 ... -0.00535548 -0.6497105\n", - " -0.2011079 ]\n", - " [-0.05800886 0.13750821 0.02450509 ... 0.245736 0.07425706\n", - " -0.17761081]]\n", - "param grad: rnn.rnn_stacks.1.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.45080703 0.19005743 0.077441 ... -0.24504453 0.19666554\n", - " -0.10503208]\n", - "param grad: rnn.rnn_stacks.1.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237206 0.03389215 ... -0.35602498 0.25528812\n", - " 0.11344345]\n", - "param grad: rnn.rnn_stacks.1.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.1.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.48457903 0.04466334 -0.19785863 ... -0.0254025 -0.10338341\n", - " -0.29202533]\n", - " [-0.15261276 0.00412052 0.22198747 ... 0.22460426 -0.03752084\n", - " 0.05170784]\n", - " [-0.09337254 0.02530848 0.1263681 ... -0.02056236 0.33342454\n", - " -0.08760723]\n", - " ...\n", - " [-0.28645608 -0.19169135 -0.1361257 ... -0.00444204 -0.06552711\n", - " -0.14726155]\n", - " [ 0.21883707 0.2049045 0.23723911 ... 0.4626113 -0.14110637\n", - " 0.02569831]\n", - " [ 0.37554163 -0.19249167 0.14591683 ... 0.25602737 0.40088275\n", - " 0.41056633]]\n", - "param grad: rnn.rnn_stacks.1.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.55867654 0.04237211 0.0338921 ... -0.35602498 0.2552881\n", - " 0.11344352]\n", - "param grad: rnn.rnn_stacks.1.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.1.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_fc.weight: shape: [2048, 1024] stop_grad: False grad: [[-0.28007814 -0.09206 -0.01297755 ... -0.2557205 -0.2693453\n", - " 0.05862035]\n", - " [-0.34194735 -0.01383794 -0.06490533 ... -0.11063005 0.16226721\n", - " -0.3197178 ]\n", - " [-0.3646778 0.15443833 0.02241019 ... -0.15093157 -0.09886418\n", - " -0.44295847]\n", - " ...\n", - " [-0.01041886 -0.57636976 -0.03988511 ... -0.2260822 0.49646813\n", - " -0.15528557]\n", - " [-0.19385241 -0.56451964 -0.05551083 ... -0.5638106 0.43611372\n", - " -0.61484563]\n", - " [ 0.1051331 -0.4762463 0.11194798 ... -0.26766616 -0.30734932\n", - " 0.17856634]]\n", - "param grad: rnn.rnn_stacks.2.fw_bn.weight: shape: [1024] stop_grad: False grad: [-0.02791309 -0.992517 0.63012564 ... -1.1830902 1.4646478\n", - " 1.6333911 ]\n", - "param grad: rnn.rnn_stacks.2.fw_bn.bias: shape: [1024] stop_grad: False grad: [-0.10834587 -1.7079136 0.81259465 ... -1.4478713 1.455745\n", - " 2.069446 ]\n", - "param grad: rnn.rnn_stacks.2.fw_bn._mean: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_bn._variance: shape: [1024] stop_grad: True grad: None\n", - "param grad: rnn.rnn_stacks.2.fw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: [[-0.14363798 -0.06933184 0.02901152 ... -0.19233373 -0.03206367\n", - " -0.00845779]\n", - " [-0.44314507 -0.8921327 -1.031872 ... -0.558997 -0.53070104\n", - " -0.855925 ]\n", - " [ 0.15673254 0.28793585 0.13351494 ... 0.38433537 0.5040767\n", - " 0.11303265]\n", - " ...\n", - " [-0.22923109 -0.62508404 -0.6195032 ... -0.6876448 -0.41718128\n", - " -0.74844164]\n", - " [ 0.18024652 0.45618314 0.81391454 ... 0.5780604 0.87566674\n", - " 0.71526295]\n", - " [ 0.3763076 0.54033077 0.9940485 ... 1.087821 0.72288674\n", - " 1.2852117 ]]\n", - "param grad: rnn.rnn_stacks.2.fw_cell.bias_hh: shape: [1024] stop_grad: False grad: [-0.10834593 -1.7079139 0.8125948 ... -1.4478711 1.4557447\n", - " 2.0694466 ]\n", - "param grad: rnn.rnn_stacks.2.bw_cell.weight_hh: shape: [1024, 1024] stop_grad: False grad: None\n", - "param grad: rnn.rnn_stacks.2.bw_cell.bias_hh: shape: [1024] stop_grad: False grad: None\n", - "param grad: fc.weight: shape: [2048, 4299] stop_grad: False grad: [[ 1.4382483e-02 2.0160766e-02 1.2322801e-02 ... 1.0075266e-02\n", - " 7.4421698e-03 -2.3925617e+01]\n", - " [ 3.7887424e-02 5.7105277e-02 2.8803380e-02 ... 2.4820438e-02\n", - " 1.8560058e-02 -5.0687141e+01]\n", - " [ 4.5566272e-02 5.4415584e-02 3.2858539e-02 ... 3.2725763e-02\n", - " 2.1536341e-02 -6.1036335e+01]\n", - " ...\n", - " [ 2.8015019e-02 3.5967816e-02 2.3228688e-02 ... 2.1284629e-02\n", - " 1.3860047e-02 -5.2543671e+01]\n", - " [ 2.8445240e-02 4.2448867e-02 2.7125146e-02 ... 2.2253662e-02\n", - " 1.7470375e-02 -4.3619675e+01]\n", - " [ 4.7438074e-02 5.8287360e-02 3.4546286e-02 ... 3.0827176e-02\n", - " 2.2168703e-02 -6.7901680e+01]]\n", - "param grad: fc.bias: shape: [4299] stop_grad: False grad: [ 8.8967547e-02 1.0697905e-01 6.5251388e-02 ... 6.1503030e-02\n", - " 4.3404289e-02 -1.3512518e+02]\n" - ] - } - ], - "source": [ - "loss.backward(retain_graph=False)\n", - "for n, p in dp_model.named_parameters():\n", - " print(\n", - " f\"param grad: {n}: shape: {p.shape} stop_grad: {p.stop_gradient} grad: {p.grad}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "selected-crazy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1.]\n" - ] - } - ], - "source": [ - "print(loss.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bottom-engineer", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "stuffed-yeast", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/.notebook/u2_confermer_model_wenet.ipynb b/.notebook/u2_confermer_model_wenet.ipynb deleted file mode 100644 index a425e16cb..000000000 --- a/.notebook/u2_confermer_model_wenet.ipynb +++ /dev/null @@ -1,4608 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-grade", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "broke-broad", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "register user softmax to paddle, remove this when fixed!\n", - "register user log_softmax to paddle, remove this when fixed!\n", - "register user sigmoid to paddle, remove this when fixed!\n", - "register user log_sigmoid to paddle, remove this when fixed!\n", - "register user relu to paddle, remove this when fixed!\n", - "override cat of paddle if exists or register, remove this when fixed!\n", - "override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle if exists or register, remove this when fixed!\n", - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "register user view to paddle.Tensor, remove this when fixed!\n", - "register user view_as to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "register user fill_ to paddle.Tensor, remove this when fixed!\n", - "register user repeat to paddle.Tensor, remove this when fixed!\n", - "register user softmax to paddle.Tensor, remove this when fixed!\n", - "register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "register user relu to paddle.Tensor, remove this when fixed!\n", - "register user type_as to paddle.Tensor, remove this when fixed!\n", - "register user to to paddle.Tensor, remove this when fixed!\n", - "register user float to paddle.Tensor, remove this when fixed!\n", - "register user tolist to paddle.Tensor, remove this when fixed!\n", - "register user glu to paddle.nn.functional, remove this when fixed!\n", - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "register user Module to paddle.nn, remove this when fixed!\n", - "register user ModuleList to paddle.nn, remove this when fixed!\n", - "register user GLU to paddle.nn, remove this when fixed!\n", - "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "register user export to paddle.jit, remove this when fixed!\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import paddle\n", - "from yacs.config import CfgNode as CN\n", - "\n", - "from deepspeech.models.u2 import U2Model\n", - "from deepspeech.utils.layer_tools import print_params\n", - "from deepspeech.utils.layer_tools import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "permanent-summary", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[INFO 2021/04/20 03:32:21 u2.py:834] U2 Encoder type: conformer\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "encoder.embed.conv.0.bias | [256] | 256 | True\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "encoder.embed.conv.2.bias | [256] | 256 | True\n", - "encoder.embed.out.0.weight | [4864, 256] | 1245184 | True\n", - "encoder.embed.out.0.bias | [256] | 256 | True\n", - "encoder.after_norm.weight | [256] | 256 | True\n", - "encoder.after_norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.0.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.1.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.1.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.2.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.2.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.3.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.3.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.4.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.4.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.5.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.5.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.6.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.6.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.7.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.7.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.8.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.8.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.9.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.9.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.10.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.10.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm.weight | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm.bias | [256] | 256 | True\n", - "encoder.encoders.11.conv_module.norm._mean | [256] | 256 | False\n", - "encoder.encoders.11.conv_module.norm._variance | [256] | 256 | False\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_mha.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_mha.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_conv.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_conv.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm_final.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm_final.bias | [256] | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", - "decoder.after_norm.weight | [256] | 256 | True\n", - "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", - "decoder.output_layer.bias | [4233] | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", - "ctc.ctc_lo.bias | [4233] | 4233 | True\n", - "Total parameters: 687.0, 49355282.0 elements.\n" - ] - } - ], - "source": [ - "conf_str='examples/aishell/s1/conf/conformer.yaml'\n", - "cfg = CN().load_cfg(open(conf_str))\n", - "cfg.model.input_dim = 80\n", - "cfg.model.output_dim = 4233\n", - "cfg.model.cmvn_file = \"/workspace/wenet/examples/aishell/s0/raw_wav/train/global_cmvn\"\n", - "cfg.model.cmvn_file_type = 'json'\n", - "cfg.freeze()\n", - "\n", - "model = U2Model(cfg.model)\n", - "print_params(model)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "sapphire-agent", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.global_cmvn.mean | [80] | 80\n", - "encoder.global_cmvn.istd | [80] | 80\n", - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304\n", - "encoder.embed.conv.0.bias | [256] | 256\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824\n", - "encoder.embed.conv.2.bias | [256] | 256\n", - "encoder.embed.out.0.weight | [4864, 256] | 1245184\n", - "encoder.embed.out.0.bias | [256] | 256\n", - "encoder.after_norm.weight | [256] | 256\n", - "encoder.after_norm.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.0.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.0.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.0.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.0.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.0.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.0.norm_ff.weight | [256] | 256\n", - "encoder.encoders.0.norm_ff.bias | [256] | 256\n", - "encoder.encoders.0.norm_mha.weight | [256] | 256\n", - "encoder.encoders.0.norm_mha.bias | [256] | 256\n", - "encoder.encoders.0.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.0.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.0.norm_conv.weight | [256] | 256\n", - "encoder.encoders.0.norm_conv.bias | [256] | 256\n", - "encoder.encoders.0.norm_final.weight | [256] | 256\n", - "encoder.encoders.0.norm_final.bias | [256] | 256\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.1.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.1.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.1.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.1.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.1.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.1.norm_ff.weight | [256] | 256\n", - "encoder.encoders.1.norm_ff.bias | [256] | 256\n", - "encoder.encoders.1.norm_mha.weight | [256] | 256\n", - "encoder.encoders.1.norm_mha.bias | [256] | 256\n", - "encoder.encoders.1.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.1.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.1.norm_conv.weight | [256] | 256\n", - "encoder.encoders.1.norm_conv.bias | [256] | 256\n", - "encoder.encoders.1.norm_final.weight | [256] | 256\n", - "encoder.encoders.1.norm_final.bias | [256] | 256\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.2.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.2.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.2.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.2.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.2.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.2.norm_ff.weight | [256] | 256\n", - "encoder.encoders.2.norm_ff.bias | [256] | 256\n", - "encoder.encoders.2.norm_mha.weight | [256] | 256\n", - "encoder.encoders.2.norm_mha.bias | [256] | 256\n", - "encoder.encoders.2.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.2.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.2.norm_conv.weight | [256] | 256\n", - "encoder.encoders.2.norm_conv.bias | [256] | 256\n", - "encoder.encoders.2.norm_final.weight | [256] | 256\n", - "encoder.encoders.2.norm_final.bias | [256] | 256\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.3.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.3.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.3.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.3.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.3.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.3.norm_ff.weight | [256] | 256\n", - "encoder.encoders.3.norm_ff.bias | [256] | 256\n", - "encoder.encoders.3.norm_mha.weight | [256] | 256\n", - "encoder.encoders.3.norm_mha.bias | [256] | 256\n", - "encoder.encoders.3.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.3.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.3.norm_conv.weight | [256] | 256\n", - "encoder.encoders.3.norm_conv.bias | [256] | 256\n", - "encoder.encoders.3.norm_final.weight | [256] | 256\n", - "encoder.encoders.3.norm_final.bias | [256] | 256\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.4.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.4.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.4.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.4.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.4.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.4.norm_ff.weight | [256] | 256\n", - "encoder.encoders.4.norm_ff.bias | [256] | 256\n", - "encoder.encoders.4.norm_mha.weight | [256] | 256\n", - "encoder.encoders.4.norm_mha.bias | [256] | 256\n", - "encoder.encoders.4.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.4.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.4.norm_conv.weight | [256] | 256\n", - "encoder.encoders.4.norm_conv.bias | [256] | 256\n", - "encoder.encoders.4.norm_final.weight | [256] | 256\n", - "encoder.encoders.4.norm_final.bias | [256] | 256\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.5.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.5.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.5.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.5.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.5.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.5.norm_ff.weight | [256] | 256\n", - "encoder.encoders.5.norm_ff.bias | [256] | 256\n", - "encoder.encoders.5.norm_mha.weight | [256] | 256\n", - "encoder.encoders.5.norm_mha.bias | [256] | 256\n", - "encoder.encoders.5.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.5.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.5.norm_conv.weight | [256] | 256\n", - "encoder.encoders.5.norm_conv.bias | [256] | 256\n", - "encoder.encoders.5.norm_final.weight | [256] | 256\n", - "encoder.encoders.5.norm_final.bias | [256] | 256\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.6.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.6.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.6.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.6.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.6.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.6.norm_ff.weight | [256] | 256\n", - "encoder.encoders.6.norm_ff.bias | [256] | 256\n", - "encoder.encoders.6.norm_mha.weight | [256] | 256\n", - "encoder.encoders.6.norm_mha.bias | [256] | 256\n", - "encoder.encoders.6.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.6.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.6.norm_conv.weight | [256] | 256\n", - "encoder.encoders.6.norm_conv.bias | [256] | 256\n", - "encoder.encoders.6.norm_final.weight | [256] | 256\n", - "encoder.encoders.6.norm_final.bias | [256] | 256\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.7.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.7.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.7.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.7.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.7.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.7.norm_ff.weight | [256] | 256\n", - "encoder.encoders.7.norm_ff.bias | [256] | 256\n", - "encoder.encoders.7.norm_mha.weight | [256] | 256\n", - "encoder.encoders.7.norm_mha.bias | [256] | 256\n", - "encoder.encoders.7.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.7.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.7.norm_conv.weight | [256] | 256\n", - "encoder.encoders.7.norm_conv.bias | [256] | 256\n", - "encoder.encoders.7.norm_final.weight | [256] | 256\n", - "encoder.encoders.7.norm_final.bias | [256] | 256\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.8.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.8.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.8.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.8.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.8.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.8.norm_ff.weight | [256] | 256\n", - "encoder.encoders.8.norm_ff.bias | [256] | 256\n", - "encoder.encoders.8.norm_mha.weight | [256] | 256\n", - "encoder.encoders.8.norm_mha.bias | [256] | 256\n", - "encoder.encoders.8.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.8.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.8.norm_conv.weight | [256] | 256\n", - "encoder.encoders.8.norm_conv.bias | [256] | 256\n", - "encoder.encoders.8.norm_final.weight | [256] | 256\n", - "encoder.encoders.8.norm_final.bias | [256] | 256\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.9.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.9.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.9.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.9.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.9.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.9.norm_ff.weight | [256] | 256\n", - "encoder.encoders.9.norm_ff.bias | [256] | 256\n", - "encoder.encoders.9.norm_mha.weight | [256] | 256\n", - "encoder.encoders.9.norm_mha.bias | [256] | 256\n", - "encoder.encoders.9.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.9.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.9.norm_conv.weight | [256] | 256\n", - "encoder.encoders.9.norm_conv.bias | [256] | 256\n", - "encoder.encoders.9.norm_final.weight | [256] | 256\n", - "encoder.encoders.9.norm_final.bias | [256] | 256\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.10.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.10.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.10.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.10.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.10.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.10.norm_ff.weight | [256] | 256\n", - "encoder.encoders.10.norm_ff.bias | [256] | 256\n", - "encoder.encoders.10.norm_mha.weight | [256] | 256\n", - "encoder.encoders.10.norm_mha.bias | [256] | 256\n", - "encoder.encoders.10.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.10.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.10.norm_conv.weight | [256] | 256\n", - "encoder.encoders.10.norm_conv.bias | [256] | 256\n", - "encoder.encoders.10.norm_final.weight | [256] | 256\n", - "encoder.encoders.10.norm_final.bias | [256] | 256\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.pos_bias_u | [4, 64] | 256\n", - "encoder.encoders.11.self_attn.pos_bias_v | [4, 64] | 256\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256\n", - "encoder.encoders.11.self_attn.linear_pos.weight | [256, 256] | 65536\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | [256, 2048] | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | [2048] | 2048\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | [2048, 256] | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | [512, 256, 1] | 131072\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | [512] | 512\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | [256, 1, 15] | 3840\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.norm.weight | [256] | 256\n", - "encoder.encoders.11.conv_module.norm.bias | [256] | 256\n", - "encoder.encoders.11.conv_module.norm._mean | [256] | 256\n", - "encoder.encoders.11.conv_module.norm._variance | [256] | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | [256, 256, 1] | 65536\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | [256] | 256\n", - "encoder.encoders.11.norm_ff.weight | [256] | 256\n", - "encoder.encoders.11.norm_ff.bias | [256] | 256\n", - "encoder.encoders.11.norm_mha.weight | [256] | 256\n", - "encoder.encoders.11.norm_mha.bias | [256] | 256\n", - "encoder.encoders.11.norm_ff_macaron.weight | [256] | 256\n", - "encoder.encoders.11.norm_ff_macaron.bias | [256] | 256\n", - "encoder.encoders.11.norm_conv.weight | [256] | 256\n", - "encoder.encoders.11.norm_conv.bias | [256] | 256\n", - "encoder.encoders.11.norm_final.weight | [256] | 256\n", - "encoder.encoders.11.norm_final.bias | [256] | 256\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256\n", - "decoder.embed.0.weight | [4233, 256] | 1083648\n", - "decoder.after_norm.weight | [256] | 256\n", - "decoder.after_norm.bias | [256] | 256\n", - "decoder.output_layer.weight | [256, 4233] | 1083648\n", - "decoder.output_layer.bias | [4233] | 4233\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.0.norm1.weight | [256] | 256\n", - "decoder.decoders.0.norm1.bias | [256] | 256\n", - "decoder.decoders.0.norm2.weight | [256] | 256\n", - "decoder.decoders.0.norm2.bias | [256] | 256\n", - "decoder.decoders.0.norm3.weight | [256] | 256\n", - "decoder.decoders.0.norm3.bias | [256] | 256\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.1.norm1.weight | [256] | 256\n", - "decoder.decoders.1.norm1.bias | [256] | 256\n", - "decoder.decoders.1.norm2.weight | [256] | 256\n", - "decoder.decoders.1.norm2.bias | [256] | 256\n", - "decoder.decoders.1.norm3.weight | [256] | 256\n", - "decoder.decoders.1.norm3.bias | [256] | 256\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.2.norm1.weight | [256] | 256\n", - "decoder.decoders.2.norm1.bias | [256] | 256\n", - "decoder.decoders.2.norm2.weight | [256] | 256\n", - "decoder.decoders.2.norm2.bias | [256] | 256\n", - "decoder.decoders.2.norm3.weight | [256] | 256\n", - "decoder.decoders.2.norm3.bias | [256] | 256\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.3.norm1.weight | [256] | 256\n", - "decoder.decoders.3.norm1.bias | [256] | 256\n", - "decoder.decoders.3.norm2.weight | [256] | 256\n", - "decoder.decoders.3.norm2.bias | [256] | 256\n", - "decoder.decoders.3.norm3.weight | [256] | 256\n", - "decoder.decoders.3.norm3.bias | [256] | 256\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.4.norm1.weight | [256] | 256\n", - "decoder.decoders.4.norm1.bias | [256] | 256\n", - "decoder.decoders.4.norm2.weight | [256] | 256\n", - "decoder.decoders.4.norm2.bias | [256] | 256\n", - "decoder.decoders.4.norm3.weight | [256] | 256\n", - "decoder.decoders.4.norm3.bias | [256] | 256\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256\n", - "decoder.decoders.5.norm1.weight | [256] | 256\n", - "decoder.decoders.5.norm1.bias | [256] | 256\n", - "decoder.decoders.5.norm2.weight | [256] | 256\n", - "decoder.decoders.5.norm2.bias | [256] | 256\n", - "decoder.decoders.5.norm3.weight | [256] | 256\n", - "decoder.decoders.5.norm3.bias | [256] | 256\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648\n", - "ctc.ctc_lo.bias | [4233] | 4233\n", - "Total parameters: 689, 49355442 elements.\n" - ] - } - ], - "source": [ - "summary(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ruled-invitation", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "U2Model(\n", - " (encoder): ConformerEncoder(\n", - " (global_cmvn): GlobalCMVN()\n", - " (embed): Conv2dSubsampling4(\n", - " (pos_enc): RelPositionalEncoding(\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " )\n", - " (conv): Sequential(\n", - " (0): Conv2D(1, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", - " (1): ReLU()\n", - " (2): Conv2D(256, 256, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)\n", - " (3): ReLU()\n", - " )\n", - " (out): Sequential(\n", - " (0): Linear(in_features=4864, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (encoders): LayerList(\n", - " (0): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (1): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (2): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (3): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (4): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (5): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (6): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (7): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (8): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (9): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (10): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (11): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " (linear_pos): Linear(in_features=256, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1D(256, 512, kernel_size=[1], data_format=NCL)\n", - " (depthwise_conv): Conv1D(256, 256, kernel_size=[15], padding=7, groups=256, data_format=NCL)\n", - " (norm): BatchNorm1D(num_features=256, momentum=0.9, epsilon=1e-05)\n", - " (pointwise_conv2): Conv1D(256, 256, kernel_size=[1], data_format=NCL)\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_mha): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_ff_macaron): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_conv): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm_final): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " )\n", - " (decoder): TransformerDecoder(\n", - " (embed): Sequential(\n", - " (0): Embedding(4233, 256, sparse=False)\n", - " (1): PositionalEncoding(\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (output_layer): Linear(in_features=256, out_features=4233, dtype=float32)\n", - " (decoders): LayerList(\n", - " (0): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_k): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_v): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (linear_out): Linear(in_features=256, out_features=256, dtype=float32)\n", - " (dropout): Dropout(p=0.0, axis=None, mode=upscale_in_train)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, dtype=float32)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (w_2): Linear(in_features=2048, out_features=256, dtype=float32)\n", - " )\n", - " (norm1): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm2): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (norm3): LayerNorm(normalized_shape=[256], epsilon=1e-12)\n", - " (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, dtype=float32)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, dtype=float32)\n", - " )\n", - " )\n", - " )\n", - " (ctc): CTCDecoder(\n", - " (ctc_lo): Linear(in_features=256, out_features=4233, dtype=float32)\n", - " (criterion): CTCLoss(\n", - " (loss): CTCLoss()\n", - " )\n", - " )\n", - " (criterion_att): LabelSmoothingLoss(\n", - " (criterion): KLDivLoss()\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fossil-means", - "metadata": {}, - "outputs": [], - "source": [ - "# load feat" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "fleet-despite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "compute_cmvn_loader_test.ipynb encoder.npz\r\n", - "dataloader.ipynb hack_api_test.ipynb\r\n", - "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", - "data.npz layer_norm_test.ipynb\r\n", - "decoder.npz Linear_test.ipynb\r\n", - "enc_0_ff_out.npz mask_and_masked_fill_test.ipynb\r\n", - "enc_0_norm_ff.npz model.npz\r\n", - "enc_0.npz position_embeding_check.ipynb\r\n", - "enc_0_selattn_out.npz python_test.ipynb\r\n", - "enc_2.npz train_test.ipynb\r\n", - "enc_all.npz u2_model.ipynb\r\n", - "enc_embed.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "abroad-oracle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['BAC009S0739W0246' 'BAC009S0727W0424' 'BAC009S0753W0412'\n", - " 'BAC009S0756W0206' 'BAC009S0740W0414' 'BAC009S0728W0426'\n", - " 'BAC009S0739W0214' 'BAC009S0753W0423' 'BAC009S0734W0201'\n", - " 'BAC009S0740W0427' 'BAC009S0730W0423' 'BAC009S0728W0367'\n", - " 'BAC009S0730W0418' 'BAC009S0727W0157' 'BAC009S0749W0409'\n", - " 'BAC009S0727W0418']\n", - "(16, 207, 80)\n", - "[[[ 8.994624 9.538309 9.191589 ... 10.507416 9.563305 8.256403 ]\n", - " [ 9.798841 10.405224 9.26511 ... 10.251211 9.543982 8.873768 ]\n", - " [10.6890745 10.395469 8.053548 ... 9.906749 10.064903 8.050915 ]\n", - " ...\n", - " [ 9.217986 9.65069 8.505259 ... 9.687183 8.742463 7.9865475]\n", - " [10.129122 9.935194 9.37982 ... 9.563894 9.825992 8.979543 ]\n", - " [ 9.095531 7.1338377 9.468001 ... 9.472748 9.021235 7.447914 ]]\n", - "\n", - " [[11.430976 10.671858 6.0841026 ... 9.382682 8.729745 7.5315614]\n", - " [ 9.731717 7.8104815 7.5714607 ... 10.043035 9.243595 7.3540792]\n", - " [10.65017 10.600604 8.467784 ... 9.281448 9.186885 8.070343 ]\n", - " ...\n", - " [ 9.096987 9.2637 8.075275 ... 8.431845 8.370505 8.002926 ]\n", - " [10.461651 10.147784 6.7693496 ... 9.779426 9.577453 8.080652 ]\n", - " [ 7.794432 5.621059 7.9750648 ... 9.997245 9.849678 8.031287 ]]\n", - "\n", - " [[ 7.3455667 7.896357 7.5795946 ... 11.631024 10.451254 9.123633 ]\n", - " [ 8.628678 8.4630575 7.499242 ... 12.415986 10.975749 8.9425745]\n", - " [ 9.831394 10.2812805 8.97241 ... 12.1386795 10.40175 9.005517 ]\n", - " ...\n", - " [ 7.089641 7.405548 6.8142557 ... 9.325196 9.273162 8.353427 ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " ...\n", - "\n", - " [[10.933237 10.464394 7.7202725 ... 10.348816 9.302338 7.1553144]\n", - " [10.449866 9.907033 9.029272 ... 9.952465 9.414051 7.559279 ]\n", - " [10.487655 9.81259 9.895244 ... 9.58662 9.341254 7.7849016]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " [[ 9.944384 9.585867 8.220328 ... 11.588647 11.045029 8.817075 ]\n", - " [ 7.678356 8.322397 7.533047 ... 11.055085 10.535685 9.27465 ]\n", - " [ 8.626197 9.675917 9.841045 ... 11.378827 10.922112 8.991444 ]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]\n", - "\n", - " [[ 8.107938 7.759043 6.710301 ... 12.650573 11.466156 11.061517 ]\n", - " [11.380332 11.222007 8.658889 ... 12.810616 12.222216 11.689288 ]\n", - " [10.677676 9.920579 8.046089 ... 13.572894 12.5624075 11.155033 ]\n", - " ...\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]\n", - " [ 0. 0. 0. ... 0. 0. 0. ]]]\n", - "[207 207 205 205 203 203 198 197 195 188 186 186 185 180 166 163]\n", - "[[2995 3116 1209 565 -1 -1]\n", - " [ 236 1176 331 66 3925 4077]\n", - " [2693 524 234 1145 366 -1]\n", - " [3875 4211 3062 700 -1 -1]\n", - " [ 272 987 1134 494 2959 -1]\n", - " [1936 3715 120 2553 2695 2710]\n", - " [ 25 1149 3930 -1 -1 -1]\n", - " [1753 1778 1237 482 3925 110]\n", - " [3703 2 565 3827 -1 -1]\n", - " [1150 2734 10 2478 3490 -1]\n", - " [ 426 811 95 489 144 -1]\n", - " [2313 2006 489 975 -1 -1]\n", - " [3702 3414 205 1488 2966 1347]\n", - " [ 70 1741 702 1666 -1 -1]\n", - " [ 703 1778 1030 849 -1 -1]\n", - " [ 814 1674 115 3827 -1 -1]]\n", - "[4 6 5 4 5 6 3 6 4 5 5 4 6 4 4 4]\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/data.npz', allow_pickle=True)\n", - "keys=data['keys']\n", - "feat=data['feat']\n", - "feat_len=data['feat_len']\n", - "text=data['text']\n", - "text_len=data['text_len']\n", - "print(keys)\n", - "print(feat.shape)\n", - "print(feat)\n", - "print(feat_len)\n", - "print(text)\n", - "print(text_len)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "false-instrument", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "arctic-proxy", - "metadata": {}, - "outputs": [], - "source": [ - "# ['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", - "# torch.Size([16, 207, 80])\n", - "# tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", - "# [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", - "# [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", - "# ...,\n", - "# [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", - "# [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", - "# [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", - "\n", - "# [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", - "# [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", - "# [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", - "# ...,\n", - "# [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", - "# [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", - "# [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", - "\n", - "# [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", - "# [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", - "# [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", - "# ...,\n", - "# [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# ...,\n", - "\n", - "# [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", - "# [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", - "# [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", - "# [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", - "# [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - "# [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", - "# [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", - "# [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", - "# ...,\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - "# [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", - "# tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", - "# 166, 163], dtype=torch.int32)\n", - "# tensor([[2995, 3116, 1209, 565, -1, -1],\n", - "# [ 236, 1176, 331, 66, 3925, 4077],\n", - "# [2693, 524, 234, 1145, 366, -1],\n", - "# [3875, 4211, 3062, 700, -1, -1],\n", - "# [ 272, 987, 1134, 494, 2959, -1],\n", - "# [1936, 3715, 120, 2553, 2695, 2710],\n", - "# [ 25, 1149, 3930, -1, -1, -1],\n", - "# [1753, 1778, 1237, 482, 3925, 110],\n", - "# [3703, 2, 565, 3827, -1, -1],\n", - "# [1150, 2734, 10, 2478, 3490, -1],\n", - "# [ 426, 811, 95, 489, 144, -1],\n", - "# [2313, 2006, 489, 975, -1, -1],\n", - "# [3702, 3414, 205, 1488, 2966, 1347],\n", - "# [ 70, 1741, 702, 1666, -1, -1],\n", - "# [ 703, 1778, 1030, 849, -1, -1],\n", - "# [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", - "# tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "seasonal-switch", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "defined-brooks", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "compute_cmvn_loader_test.ipynb\t encoder.npz\r\n", - "dataloader.ipynb\t\t hack_api_test.ipynb\r\n", - "dataloader_with_tokens_tokenids.ipynb jit_infer.ipynb\r\n", - "data.npz\t\t\t layer_norm_test.ipynb\r\n", - "decoder.npz\t\t\t Linear_test.ipynb\r\n", - "enc_0_ff_out.npz\t\t mask_and_masked_fill_test.ipynb\r\n", - "enc_0_norm_ff.npz\t\t model.npz\r\n", - "enc_0.npz\t\t\t position_embeding_check.ipynb\r\n", - "enc_0_selattn_out.npz\t\t python_test.ipynb\r\n", - "enc_2.npz\t\t\t train_test.ipynb\r\n", - "enc_all.npz\t\t\t u2_model.ipynb\r\n", - "enc_embed.npz\r\n" - ] - } - ], - "source": [ - "# load model param\n", - "!ls .notebook\n", - "data = np.load('.notebook/model.npz', allow_pickle=True)\n", - "state_dict = data['state'].item()\n", - "\n", - "for key, _ in model.state_dict().items():\n", - " if key not in state_dict:\n", - " print(f\"{key} not find.\")\n", - "\n", - "model.set_state_dict(state_dict)\n", - "\n", - "now_state_dict = model.state_dict()\n", - "for key, value in now_state_dict.items():\n", - " if not np.allclose(value.numpy(), state_dict[key]):\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "exempt-viewer", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "confident-piano", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/framework.py:687: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " elif dtype == np.bool:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [142.48880005]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118]) Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [377.33258057])\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:238: UserWarning: The dtype of left and right variables are not the same, left dtype is VarType.FP32, but right dtype is VarType.INT32, the right dtype will convert to VarType.FP32\n", - " format(lhs_dtype, rhs_dtype, lhs_dtype))\n" - ] - } - ], - "source": [ - "# compute loss\n", - "import paddle\n", - "feat=paddle.to_tensor(feat)\n", - "feat_len=paddle.to_tensor(feat_len, dtype='int64')\n", - "text=paddle.to_tensor(text, dtype='int64')\n", - "text_len=paddle.to_tensor(text_len, dtype='int64')\n", - "\n", - "model.eval()\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss, attention_loss, ctc_loss )" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "better-senator", - "metadata": {}, - "outputs": [], - "source": [ - "# tensor(142.4888, device='cuda:0', grad_fn=) \n", - "# tensor(41.8415, device='cuda:0', grad_fn=) \n", - "# tensor(377.3326, device='cuda:0', grad_fn=)\n", - "# 142.4888 41.84146 377.33258" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "related-banking", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "olympic-problem", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[16, 51, 256]\n", - "[16, 1, 51]\n", - "Tensor(shape=[51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[-0.70194179, 0.56254166, 0.68803459, ..., 1.12373221, 0.78039235, 1.13693869],\n", - " [-0.77877808, 0.39126658, 0.71887815, ..., 1.25188220, 0.88616788, 1.31734526],\n", - " [-0.95908946, 0.63460249, 0.87671334, ..., 0.98183727, 0.74401081, 1.29032660],\n", - " ...,\n", - " [-1.07322502, 0.67236906, 0.92303109, ..., 0.90754563, 0.81767166, 1.32396567],\n", - " [-1.16541159, 0.68199694, 0.69394493, ..., 1.22383487, 0.80282891, 1.45065081],\n", - " [-1.27320945, 0.71458030, 0.75819558, ..., 0.94154912, 0.87748396, 1.26230514]])\n" - ] - } - ], - "source": [ - "# ecnoder\n", - "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)\n", - "print(encoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "shaped-alaska", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "deepspeech examples README_cn.md\tsetup.sh tools\r\n", - "docs\t LICENSE README.md\t\ttests\t utils\r\n", - "env.sh\t log requirements.txt\tthird_party\r\n" - ] - } - ], - "source": [ - "!ls\n", - "data = np.load('.notebook/encoder.npz', allow_pickle=True)\n", - "torch_mask = data['mask']\n", - "torch_encoder_out = data['out']" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "federal-rover", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n" - ] - } - ], - "source": [ - "print(np.testing.assert_equal(torch_mask, encoder_mask.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "regulated-interstate", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n", - " 1.2623053 ]]\n", - "----\n", - "[[-0.7019418 0.56254166 0.6880346 ... 1.1237322 0.78039235\n", - " 1.1369387 ]\n", - " [-0.7787781 0.39126658 0.71887815 ... 1.2518822 0.8861679\n", - " 1.3173453 ]\n", - " [-0.95908946 0.6346025 0.87671334 ... 0.9818373 0.7440108\n", - " 1.2903266 ]\n", - " ...\n", - " [-1.073225 0.67236906 0.9230311 ... 0.9075456 0.81767166\n", - " 1.3239657 ]\n", - " [-1.1654116 0.68199694 0.69394493 ... 1.2238349 0.8028289\n", - " 1.4506508 ]\n", - " [-1.2732095 0.7145803 0.7581956 ... 0.9415491 0.87748396\n", - " 1.2623051 ]]\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", - "print(torch_encoder_out[0])\n", - "print(\"----\")\n", - "print(encoder_out.numpy()[0])\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5, rtol=1e-6))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6, rtol=1e-6))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "proof-scheduling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [377.33258057])\n", - "[1.]\n", - "[[ 3.16902876e+00 -1.51763987e-02 4.91095744e-02 ... -2.47971853e-03\n", - " -5.93360700e-03 -7.26609165e-03]\n", - " [-1.74184477e+00 7.75874173e-03 -4.49434854e-02 ... 9.92412097e-04\n", - " 2.46337592e-03 2.31892057e-03]\n", - " [-2.33343339e+00 1.30475955e-02 -2.66557075e-02 ... 2.27532350e-03\n", - " 5.76924905e-03 7.48788286e-03]\n", - " ...\n", - " [-4.30358458e+00 2.46054661e-02 -9.00950655e-02 ... 4.43156436e-03\n", - " 1.16122244e-02 1.44715561e-02]\n", - " [-3.36921120e+00 1.73153952e-02 -6.36872873e-02 ... 3.28363618e-03\n", - " 8.58010259e-03 1.07794888e-02]\n", - " [-6.62045336e+00 3.49955931e-02 -1.23962618e-01 ... 6.36671018e-03\n", - " 1.60814095e-02 2.03891303e-02]]\n", - "[-4.3777819e+00 2.3245810e-02 -9.3339294e-02 ... 4.2569344e-03\n", - " 1.0919910e-02 1.3787797e-02]\n" - ] - } - ], - "source": [ - "from paddle.nn import functional as F\n", - "def ctc_loss(logits,\n", - " labels,\n", - " input_lengths,\n", - " label_lengths,\n", - " blank=0,\n", - " reduction='mean',\n", - " norm_by_times=False):\n", - " loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,\n", - " input_lengths, label_lengths)\n", - " loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])\n", - " assert reduction in ['mean', 'sum', 'none']\n", - " if reduction == 'mean':\n", - " loss_out = paddle.mean(loss_out / label_lengths)\n", - " elif reduction == 'sum':\n", - " loss_out = paddle.sum(loss_out)\n", - " return loss_out\n", - "\n", - "F.ctc_loss = ctc_loss\n", - "\n", - "torch_mask_t = paddle.to_tensor(torch_mask, dtype='int64')\n", - "encoder_out_lens = torch_mask_t.squeeze(1).sum(1)\n", - "loss_ctc = model.ctc(paddle.to_tensor(torch_encoder_out), encoder_out_lens, text, text_len)\n", - "print(loss_ctc)\n", - "loss_ctc.backward()\n", - "print(loss_ctc.grad)\n", - "print(model.ctc.ctc_lo.weight.grad)\n", - "print(model.ctc.ctc_lo.bias.grad)\n", - "\n", - "\n", - "# tensor(377.3326, device='cuda:0', grad_fn=)\n", - "# None\n", - "# [[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", - "# -5.93366381e-03 -7.26613170e-03]\n", - "# [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", - "# 2.46338220e-03 2.31891591e-03]\n", - "# [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", - "# 5.76929189e-03 7.48792710e-03]\n", - "# ...\n", - "# [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", - "# 1.16123557e-02 1.44716976e-02]\n", - "# [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", - "# 8.58021621e-03 1.07796099e-02]\n", - "# [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", - "# 1.60815325e-02 2.03892551e-02]]\n", - "# [-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", - "# 1.0920014e-02 1.3787906e-02]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "enclosed-consolidation", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "synthetic-hungarian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118]) 0.0\n" - ] - } - ], - "source": [ - "loss_att, acc_att = model._calc_att_loss(paddle.to_tensor(torch_encoder_out), paddle.to_tensor(torch_mask),\n", - " text, text_len)\n", - "print(loss_att, acc_att)\n", - "#tensor(41.8416, device='cuda:0', grad_fn=) 0.0" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "indian-sweden", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 202, - "id": "marine-cuisine", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", - " 1.5034772e-02 4.0337229e-01]\n", - " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", - " -1.4352810e-01 -1.0023664e+00]\n", - " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", - " -1.1672243e+00 -2.6848501e-01]\n", - " ...\n", - " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", - " 4.5470044e-02 -3.7139410e-01]\n", - " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", - " -7.9347193e-04 4.2537671e-01]\n", - " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 2.6173013e-01\n", - " -1.0499060e-03 4.2678756e-01]]\n" - ] - } - ], - "source": [ - "data = np.load(\".notebook/decoder.npz\", allow_pickle=True)\n", - "torch_decoder_out = data['decoder_out']\n", - "print(torch_decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 180, - "id": "several-result", - "metadata": {}, - "outputs": [], - "source": [ - "def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,\n", - " ignore_id: int):\n", - " \"\"\"Add and labels.\n", - " Args:\n", - " ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)\n", - " sos (int): index of \n", - " eos (int): index of \n", - " ignore_id (int): index of padding\n", - " Returns:\n", - " ys_in (paddle.Tensor) : (B, Lmax + 1)\n", - " ys_out (paddle.Tensor) : (B, Lmax + 1)\n", - " Examples:\n", - " >>> sos_id = 10\n", - " >>> eos_id = 11\n", - " >>> ignore_id = -1\n", - " >>> ys_pad\n", - " tensor([[ 1, 2, 3, 4, 5],\n", - " [ 4, 5, 6, -1, -1],\n", - " [ 7, 8, 9, -1, -1]], dtype=paddle.int32)\n", - " >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)\n", - " >>> ys_in\n", - " tensor([[10, 1, 2, 3, 4, 5],\n", - " [10, 4, 5, 6, 11, 11],\n", - " [10, 7, 8, 9, 11, 11]])\n", - " >>> ys_out\n", - " tensor([[ 1, 2, 3, 4, 5, 11],\n", - " [ 4, 5, 6, 11, -1, -1],\n", - " [ 7, 8, 9, 11, -1, -1]])\n", - " \"\"\"\n", - " # TODO(Hui Zhang): using comment code, \n", - " #_sos = paddle.to_tensor(\n", - " # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", - " #_eos = paddle.to_tensor(\n", - " # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)\n", - " #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", - " #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]\n", - " #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]\n", - " #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)\n", - " B = ys_pad.size(0)\n", - " _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos\n", - " _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos\n", - " ys_in = paddle.cat([_sos, ys_pad], dim=1)\n", - " mask_pad = (ys_in == ignore_id)\n", - " ys_in = ys_in.masked_fill(mask_pad, eos)\n", - " \n", - "\n", - " ys_out = paddle.cat([ys_pad, _eos], dim=1)\n", - " ys_out = ys_out.masked_fill(mask_pad, eos)\n", - " mask_eos = (ys_out == ignore_id)\n", - " ys_out = ys_out.masked_fill(mask_eos, eos)\n", - " ys_out = ys_out.masked_fill(mask_pad, ignore_id)\n", - " return ys_in, ys_out" - ] - }, - { - "cell_type": "code", - "execution_count": 181, - "id": "possible-bulgaria", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [[4232, 2995, 3116, 1209, 565 , 4232, 4232],\n", - " [4232, 236 , 1176, 331 , 66 , 3925, 4077],\n", - " [4232, 2693, 524 , 234 , 1145, 366 , 4232],\n", - " [4232, 3875, 4211, 3062, 700 , 4232, 4232],\n", - " [4232, 272 , 987 , 1134, 494 , 2959, 4232],\n", - " [4232, 1936, 3715, 120 , 2553, 2695, 2710],\n", - " [4232, 25 , 1149, 3930, 4232, 4232, 4232],\n", - " [4232, 1753, 1778, 1237, 482 , 3925, 110 ],\n", - " [4232, 3703, 2 , 565 , 3827, 4232, 4232],\n", - " [4232, 1150, 2734, 10 , 2478, 3490, 4232],\n", - " [4232, 426 , 811 , 95 , 489 , 144 , 4232],\n", - " [4232, 2313, 2006, 489 , 975 , 4232, 4232],\n", - " [4232, 3702, 3414, 205 , 1488, 2966, 1347],\n", - " [4232, 70 , 1741, 702 , 1666, 4232, 4232],\n", - " [4232, 703 , 1778, 1030, 849 , 4232, 4232],\n", - " [4232, 814 , 1674, 115 , 3827, 4232, 4232]])\n", - "Tensor(shape=[16, 7], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", - " [[2995, 3116, 1209, 565, 4232, -1 , -1 ],\n", - " [ 236, 1176, 331, 66 , 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1 ],\n", - " [3875, 4211, 3062, 700, 4232, -1 , -1 ],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1 ],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25 , 1149, 3930, 4232, -1 , -1 , -1 ],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2 , 565, 3827, 4232, -1 , -1 ],\n", - " [1150, 2734, 10 , 2478, 3490, 4232, -1 ],\n", - " [ 426, 811, 95 , 489, 144, 4232, -1 ],\n", - " [2313, 2006, 489, 975, 4232, -1 , -1 ],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70 , 1741, 702, 1666, 4232, -1 , -1 ],\n", - " [ 703, 1778, 1030, 849, 4232, -1 , -1 ],\n", - " [ 814, 1674, 115, 3827, 4232, -1 , -1 ]])\n" - ] - } - ], - "source": [ - "ys_pad = text\n", - "ys_pad_lens = text_len\n", - "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", - " model.ignore_id)\n", - "ys_in_lens = ys_pad_lens + 1\n", - "print(ys_in_pad)\n", - "print(ys_out_pad)" - ] - }, - { - "cell_type": "code", - "execution_count": 285, - "id": "north-walter", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "False\n", - "True\n", - "False\n", - "[[-3.76389682e-01 -8.22720408e-01 7.42762923e-01 ... 3.42005253e-01\n", - " 1.50350705e-02 4.03372347e-01]\n", - " [-8.73864174e-01 -3.13894272e-01 4.19878662e-01 ... 3.77237231e-01\n", - " -1.43528014e-01 -1.00236630e+00]\n", - " [-4.35050905e-01 3.45046446e-02 -2.87102997e-01 ... 7.72742853e-02\n", - " -1.16722476e+00 -2.68485069e-01]\n", - " ...\n", - " [ 4.24714804e-01 5.88856399e-01 2.02039629e-02 ... 3.74054879e-01\n", - " 4.54700664e-02 -3.71394157e-01]\n", - " [-3.79784584e-01 -8.10841978e-01 7.57250786e-01 ... 2.60389000e-01\n", - " -7.93404877e-04 4.25376773e-01]\n", - " [-3.82798851e-01 -8.12067091e-01 7.49434292e-01 ... 2.61730075e-01\n", - " -1.04988366e-03 4.26787734e-01]]\n", - "---\n", - "[[-3.7638968e-01 -8.2272053e-01 7.4276292e-01 ... 3.4200522e-01\n", - " 1.5034772e-02 4.0337229e-01]\n", - " [-8.7386459e-01 -3.1389427e-01 4.1987866e-01 ... 3.7723729e-01\n", - " -1.4352810e-01 -1.0023664e+00]\n", - " [-4.3505096e-01 3.4504786e-02 -2.8710306e-01 ... 7.7274129e-02\n", - " -1.1672243e+00 -2.6848501e-01]\n", - " ...\n", - " [ 4.2471480e-01 5.8885634e-01 2.0203922e-02 ... 3.7405500e-01\n", - " 4.5470044e-02 -3.7139410e-01]\n", - " [-3.7978446e-01 -8.1084180e-01 7.5725085e-01 ... 2.6038891e-01\n", - " -7.9347193e-04 4.2537671e-01]\n", - " [-3.8279903e-01 -8.1206715e-01 7.4943429e-01 ... 2.6173013e-01\n", - " -1.0499060e-03 4.2678756e-01]]\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out))\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-6))\n", - "print(np.allclose(decoder_out.numpy(), torch_decoder_out, atol=1e-7))\n", - "print(decoder_out.numpy()[0])\n", - "print('---')\n", - "print(torch_decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "armed-cowboy", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fifty-earth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "proud-commonwealth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 183, - "id": "assisted-fortune", - "metadata": {}, - "outputs": [], - "source": [ - "from paddle import nn\n", - "import paddle\n", - "from paddle.nn import functional as F\n", - "\n", - "class LabelSmoothingLoss(nn.Layer):\n", - "\n", - " def __init__(self,\n", - " size: int,\n", - " padding_idx: int,\n", - " smoothing: float,\n", - " normalize_length: bool=False):\n", - " super().__init__()\n", - " self.size = size\n", - " self.padding_idx = padding_idx\n", - " self.smoothing = smoothing\n", - " self.confidence = 1.0 - smoothing\n", - " self.normalize_length = normalize_length\n", - " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", - "\n", - " def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:\n", - " \"\"\"Compute loss between x and target.\n", - " The model outputs and data labels tensors are flatten to\n", - " (batch*seqlen, class) shape and a mask is applied to the\n", - " padding part which should not be calculated for loss.\n", - " \n", - " Args:\n", - " x (paddle.Tensor): prediction (batch, seqlen, class)\n", - " target (paddle.Tensor):\n", - " target signal masked with self.padding_id (batch, seqlen)\n", - " Returns:\n", - " loss (paddle.Tensor) : The KL loss, scalar float value\n", - " \"\"\"\n", - " B, T, D = paddle.shape(x)\n", - " assert D == self.size\n", - " x = x.reshape((-1, self.size))\n", - " target = target.reshape([-1])\n", - "\n", - " # use zeros_like instead of torch.no_grad() for true_dist,\n", - " # since no_grad() can not be exported by JIT\n", - " true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))\n", - " ignore = target == self.padding_idx # (B,)\n", - " print(self.smoothing / (self.size - 1))\n", - " print(true_dist)\n", - "\n", - " #target = target * (1 - ignore) # avoid -1 index\n", - " target = target.masked_fill(ignore, 0) # avoid -1 index\n", - " \n", - " \n", - " #true_dist += F.one_hot(target, self.size) * self.confidence\n", - " target_mask = F.one_hot(target, self.size)\n", - " true_dist *= (1 - target_mask)\n", - " true_dist += target_mask * self.confidence\n", - " \n", - "\n", - " kl = self.criterion(F.log_softmax(x, axis=1), true_dist)\n", - " \n", - " #TODO(Hui Zhang): sum not support bool type\n", - " #total = len(target) - int(ignore.sum())\n", - " total = len(target) - int(ignore.type_as(target).sum())\n", - " denom = total if self.normalize_length else B\n", - "\n", - " #numer = (kl * (1 - ignore)).sum()\n", - " numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", - " return numer / denom\n" - ] - }, - { - "cell_type": "code", - "execution_count": 184, - "id": "weighted-delight", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.3629489603024576e-05\n", - "Tensor(shape=[112, 4233], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " ...,\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363],\n", - " [0.00002363, 0.00002363, 0.00002363, ..., 0.00002363, 0.00002363, 0.00002363]])\n", - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [41.84146118])\n", - "VarType.INT64\n" - ] - } - ], - "source": [ - "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", - "loss_att = criteron(paddle.to_tensor(torch_decoder_out), ys_out_pad.astype('int64'))\n", - "print(loss_att)\n", - "print(ys_out_pad.dtype)\n", - "# tensor(41.8416, device='cuda:0', grad_fn=)" - ] - }, - { - "cell_type": "code", - "execution_count": 286, - "id": "dress-shelter", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [41.84146118])\n", - "Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [41.84146118])\n", - "4233\n", - "-1\n", - "0.1\n", - "False\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "\n", - "loss_att = model.criterion_att(paddle.to_tensor(torch_decoder_out), ys_out_pad)\n", - "print(loss_att)\n", - "\n", - "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "\n", - "print(model.criterion_att.size)\n", - "print(model.criterion_att.padding_idx)\n", - "print(model.criterion_att.smoothing)\n", - "print(model.criterion_att.normalize_length)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "growing-tooth", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "going-hungary", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "naughty-citizenship", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "experimental-emerald", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adverse-saskatchewan", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "speaking-shelf", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List\n", - "from typing import Optional\n", - "from typing import Tuple\n", - "\n", - "import paddle\n", - "from paddle import nn\n", - "from typeguard import check_argument_types\n", - "\n", - "from deepspeech.modules.activation import get_activation\n", - "from deepspeech.modules.attention import MultiHeadedAttention\n", - "from deepspeech.modules.attention import RelPositionMultiHeadedAttention\n", - "from deepspeech.modules.conformer_convolution import ConvolutionModule\n", - "from deepspeech.modules.embedding import PositionalEncoding\n", - "from deepspeech.modules.embedding import RelPositionalEncoding\n", - "from deepspeech.modules.encoder_layer import ConformerEncoderLayer\n", - "from deepspeech.modules.encoder_layer import TransformerEncoderLayer\n", - "from deepspeech.modules.mask import add_optional_chunk_mask\n", - "from deepspeech.modules.mask import make_non_pad_mask\n", - "from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling4\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling6\n", - "from deepspeech.modules.subsampling import Conv2dSubsampling8\n", - "from deepspeech.modules.subsampling import LinearNoSubsampling\n", - "\n", - "class BaseEncoder(nn.Layer):\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " output_size: int=256,\n", - " attention_heads: int=4,\n", - " linear_units: int=2048,\n", - " num_blocks: int=6,\n", - " dropout_rate: float=0.1,\n", - " positional_dropout_rate: float=0.1,\n", - " attention_dropout_rate: float=0.0,\n", - " input_layer: str=\"conv2d\",\n", - " pos_enc_layer_type: str=\"abs_pos\",\n", - " normalize_before: bool=True,\n", - " concat_after: bool=False,\n", - " static_chunk_size: int=0,\n", - " use_dynamic_chunk: bool=False,\n", - " global_cmvn: paddle.nn.Layer=None,\n", - " use_dynamic_left_chunk: bool=False, ):\n", - " \"\"\"\n", - " Args:\n", - " input_size (int): input dim, d_feature\n", - " output_size (int): dimension of attention, d_model\n", - " attention_heads (int): the number of heads of multi head attention\n", - " linear_units (int): the hidden units number of position-wise feed\n", - " forward\n", - " num_blocks (int): the number of encoder blocks\n", - " dropout_rate (float): dropout rate\n", - " attention_dropout_rate (float): dropout rate in attention\n", - " positional_dropout_rate (float): dropout rate after adding\n", - " positional encoding\n", - " input_layer (str): input layer type.\n", - " optional [linear, conv2d, conv2d6, conv2d8]\n", - " pos_enc_layer_type (str): Encoder positional encoding layer type.\n", - " opitonal [abs_pos, scaled_abs_pos, rel_pos]\n", - " normalize_before (bool):\n", - " True: use layer_norm before each sub-block of a layer.\n", - " False: use layer_norm after each sub-block of a layer.\n", - " concat_after (bool): whether to concat attention layer's input\n", - " and output.\n", - " True: x -> x + linear(concat(x, att(x)))\n", - " False: x -> x + att(x)\n", - " static_chunk_size (int): chunk size for static chunk training and\n", - " decoding\n", - " use_dynamic_chunk (bool): whether use dynamic chunk size for\n", - " training or not, You can only use fixed chunk(chunk_size > 0)\n", - " or dyanmic chunk size(use_dynamic_chunk = True)\n", - " global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer\n", - " use_dynamic_left_chunk (bool): whether use dynamic left chunk in\n", - " dynamic chunk training\n", - " \"\"\"\n", - " assert check_argument_types()\n", - " super().__init__()\n", - " self._output_size = output_size\n", - "\n", - " if pos_enc_layer_type == \"abs_pos\":\n", - " pos_enc_class = PositionalEncoding\n", - " elif pos_enc_layer_type == \"rel_pos\":\n", - " pos_enc_class = RelPositionalEncoding\n", - " else:\n", - " raise ValueError(\"unknown pos_enc_layer: \" + pos_enc_layer_type)\n", - "\n", - " if input_layer == \"linear\":\n", - " subsampling_class = LinearNoSubsampling\n", - " elif input_layer == \"conv2d\":\n", - " subsampling_class = Conv2dSubsampling4\n", - " elif input_layer == \"conv2d6\":\n", - " subsampling_class = Conv2dSubsampling6\n", - " elif input_layer == \"conv2d8\":\n", - " subsampling_class = Conv2dSubsampling8\n", - " else:\n", - " raise ValueError(\"unknown input_layer: \" + input_layer)\n", - "\n", - " self.global_cmvn = global_cmvn\n", - " self.embed = subsampling_class(\n", - " idim=input_size,\n", - " odim=output_size,\n", - " dropout_rate=dropout_rate,\n", - " pos_enc_class=pos_enc_class(\n", - " d_model=output_size, dropout_rate=positional_dropout_rate), )\n", - "\n", - " self.normalize_before = normalize_before\n", - " self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)\n", - " self.static_chunk_size = static_chunk_size\n", - " self.use_dynamic_chunk = use_dynamic_chunk\n", - " self.use_dynamic_left_chunk = use_dynamic_left_chunk\n", - "\n", - " def output_size(self) -> int:\n", - " return self._output_size\n", - "\n", - " def forward(\n", - " self,\n", - " xs: paddle.Tensor,\n", - " xs_lens: paddle.Tensor,\n", - " decoding_chunk_size: int=0,\n", - " num_decoding_left_chunks: int=-1,\n", - " ) -> Tuple[paddle.Tensor, paddle.Tensor]:\n", - " \"\"\"Embed positions in tensor.\n", - " Args:\n", - " xs: padded input tensor (B, L, D)\n", - " xs_lens: input length (B)\n", - " decoding_chunk_size: decoding chunk size for dynamic chunk\n", - " 0: default for training, use random dynamic chunk.\n", - " <0: for decoding, use full chunk.\n", - " >0: for decoding, use fixed chunk size as set.\n", - " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", - " the chunk size is decoding_chunk_size.\n", - " >=0: use num_decoding_left_chunks\n", - " <0: use all left chunks\n", - " Returns:\n", - " encoder output tensor, lens and mask\n", - " \"\"\"\n", - " masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, L)\n", - "\n", - " if self.global_cmvn is not None:\n", - " xs = self.global_cmvn(xs)\n", - " #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor\n", - " xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)\n", - " #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor\n", - " masks = masks.astype(paddle.bool)\n", - " #TODO(Hui Zhang): mask_pad = ~masks\n", - " mask_pad = masks.logical_not()\n", - " chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,\n", - " decoding_chunk_size, self.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - " for layer in self.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " if self.normalize_before:\n", - " xs = self.after_norm(xs)\n", - " # Here we assume the mask is not changed in encoder layers, so just\n", - " # return the masks before encoder layers, and the masks will be used\n", - " # for cross attention with decoder later\n", - " return xs, masks" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "sharp-municipality", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "class ConformerEncoder(BaseEncoder):\n", - " \"\"\"Conformer encoder module.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " output_size: int=256,\n", - " attention_heads: int=4,\n", - " linear_units: int=2048,\n", - " num_blocks: int=6,\n", - " dropout_rate: float=0.1,\n", - " positional_dropout_rate: float=0.1,\n", - " attention_dropout_rate: float=0.0,\n", - " input_layer: str=\"conv2d\",\n", - " pos_enc_layer_type: str=\"rel_pos\",\n", - " normalize_before: bool=True,\n", - " concat_after: bool=False,\n", - " static_chunk_size: int=0,\n", - " use_dynamic_chunk: bool=False,\n", - " global_cmvn: nn.Layer=None,\n", - " use_dynamic_left_chunk: bool=False,\n", - " positionwise_conv_kernel_size: int=1,\n", - " macaron_style: bool=True,\n", - " selfattention_layer_type: str=\"rel_selfattn\",\n", - " activation_type: str=\"swish\",\n", - " use_cnn_module: bool=True,\n", - " cnn_module_kernel: int=15,\n", - " causal: bool=False,\n", - " cnn_module_norm: str=\"batch_norm\", ):\n", - " \"\"\"Construct ConformerEncoder\n", - " Args:\n", - " input_size to use_dynamic_chunk, see in BaseEncoder\n", - " positionwise_conv_kernel_size (int): Kernel size of positionwise\n", - " conv1d layer.\n", - " macaron_style (bool): Whether to use macaron style for\n", - " positionwise layer.\n", - " selfattention_layer_type (str): Encoder attention layer type,\n", - " the parameter has no effect now, it's just for configure\n", - " compatibility.\n", - " activation_type (str): Encoder activation function type.\n", - " use_cnn_module (bool): Whether to use convolution module.\n", - " cnn_module_kernel (int): Kernel size of convolution module.\n", - " causal (bool): whether to use causal convolution or not.\n", - " cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']\n", - " \"\"\"\n", - " assert check_argument_types()\n", - " super().__init__(input_size, output_size, attention_heads, linear_units,\n", - " num_blocks, dropout_rate, positional_dropout_rate,\n", - " attention_dropout_rate, input_layer,\n", - " pos_enc_layer_type, normalize_before, concat_after,\n", - " static_chunk_size, use_dynamic_chunk, global_cmvn,\n", - " use_dynamic_left_chunk)\n", - " activation = get_activation(activation_type)\n", - "\n", - " # self-attention module definition\n", - " encoder_selfattn_layer = RelPositionMultiHeadedAttention\n", - " encoder_selfattn_layer_args = (attention_heads, output_size,\n", - " attention_dropout_rate)\n", - " # feed-forward module definition\n", - " positionwise_layer = PositionwiseFeedForward\n", - " positionwise_layer_args = (output_size, linear_units, dropout_rate,\n", - " activation)\n", - " # convolution module definition\n", - " convolution_layer = ConvolutionModule\n", - " convolution_layer_args = (output_size, cnn_module_kernel, activation,\n", - " cnn_module_norm, causal)\n", - "\n", - " self.encoders = nn.LayerList([\n", - " ConformerEncoderLayer(\n", - " size=output_size,\n", - " self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),\n", - " feed_forward=positionwise_layer(*positionwise_layer_args),\n", - " feed_forward_macaron=positionwise_layer(\n", - " *positionwise_layer_args) if macaron_style else None,\n", - " conv_module=convolution_layer(*convolution_layer_args)\n", - " if use_cnn_module else None,\n", - " dropout_rate=dropout_rate,\n", - " normalize_before=normalize_before,\n", - " concat_after=concat_after) for _ in range(num_blocks)\n", - " ])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "tutorial-syndication", - "metadata": {}, - "outputs": [], - "source": [ - "from deepspeech.frontend.utility import load_cmvn\n", - "from deepspeech.modules.cmvn import GlobalCMVN\n", - "\n", - "configs=cfg.model\n", - "mean, istd = load_cmvn(configs['cmvn_file'],\n", - " configs['cmvn_file_type'])\n", - "global_cmvn = GlobalCMVN(\n", - " paddle.to_tensor(mean, dtype=paddle.float),\n", - " paddle.to_tensor(istd, dtype=paddle.float))\n", - "\n", - "\n", - "input_dim = configs['input_dim']\n", - "vocab_size = configs['output_dim']\n", - "encoder_type = configs.get('encoder', 'transformer')\n", - " \n", - "encoder = ConformerEncoder(\n", - " input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "fuzzy-register", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "o = global_cmvn(feat)\n", - "o2 = model.encoder.global_cmvn(feat)\n", - "print(np.allclose(o.numpy(), o2.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "explicit-triumph", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "humanitarian-belgium", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dying-proposal", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "honest-quick", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bound-cholesterol", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "viral-packaging", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 203, - "id": "balanced-locator", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 1, 207], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[True , True , True , ..., True , True , True ]],\n", - "\n", - " [[True , True , True , ..., True , True , True ]],\n", - "\n", - " [[True , True , True , ..., True , False, False]],\n", - "\n", - " ...,\n", - "\n", - " [[True , True , True , ..., False, False, False]],\n", - "\n", - " [[True , True , True , ..., False, False, False]],\n", - "\n", - " [[True , True , True , ..., False, False, False]]])\n" - ] - } - ], - "source": [ - "from deepspeech.modules.mask import make_non_pad_mask\n", - "from deepspeech.modules.mask import make_pad_mask\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "print(masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 204, - "id": "induced-proposition", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 207, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[-0.53697914, -0.19910523, -0.34997201, ..., -0.82427669, -1.02650309, -0.96300691],\n", - " [-0.04464225, 0.23176001, -0.32538742, ..., -0.90158713, -1.03248465, -0.75986791],\n", - " [ 0.50035292, 0.22691160, -0.73052198, ..., -1.00552964, -0.87123060, -1.03062117],\n", - " ...,\n", - " [-0.40023831, -0.14325078, -0.57947433, ..., -1.07178426, -1.28059900, -1.05180073],\n", - " [ 0.15755332, -0.00184949, -0.28702953, ..., -1.10898709, -0.94518697, -0.72506356],\n", - " [-0.47520429, -1.39415145, -0.25754252, ..., -1.13649082, -1.19430351, -1.22903371]],\n", - "\n", - " [[ 0.95454037, 0.36427975, -1.38908529, ..., -1.16366839, -1.28453600, -1.20151031],\n", - " [-0.08573537, -1.05785275, -0.89172721, ..., -0.96440506, -1.12547100, -1.25990939],\n", - " [ 0.47653601, 0.32886592, -0.59200549, ..., -1.19421589, -1.14302588, -1.02422845],\n", - " ...,\n", - " [-0.47431335, -0.33558893, -0.72325647, ..., -1.45058632, -1.39574063, -1.04641151],\n", - " [ 0.36112556, 0.10380996, -1.15994537, ..., -1.04394984, -1.02212358, -1.02083635],\n", - " [-1.27172923, -2.14601755, -0.75676596, ..., -0.97822225, -0.93785471, -1.03707945]],\n", - "\n", - " [[-1.54652190, -1.01517177, -0.88900733, ..., -0.48522446, -0.75163364, -0.67765164],\n", - " [-0.76100892, -0.73351598, -0.91587651, ..., -0.24835993, -0.58927339, -0.73722762],\n", - " [-0.02471367, 0.17015894, -0.42326337, ..., -0.33203802, -0.76695800, -0.71651691],\n", - " ...,\n", - " [-1.70319796, -1.25910866, -1.14492917, ..., -1.18101490, -1.11631835, -0.93108195],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0.64982772, 0.26116797, -0.84196597, ..., -0.87213463, -1.10728693, -1.32531130],\n", - " [ 0.35391113, -0.01584581, -0.40424931, ..., -0.99173468, -1.07270539, -1.19239008],\n", - " [ 0.37704495, -0.06278508, -0.11467686, ..., -1.10212946, -1.09524000, -1.11815071],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " [[ 0.04445776, -0.17546852, -0.67475224, ..., -0.49801198, -0.56782746, -0.77852231],\n", - " [-1.34279025, -0.80342549, -0.90457231, ..., -0.65901577, -0.72549772, -0.62796098],\n", - " [-0.76252806, -0.13071291, -0.13280024, ..., -0.56132573, -0.60587686, -0.72114766],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]],\n", - "\n", - " [[-1.07980299, -1.08341801, -1.17969072, ..., -0.17757270, -0.43746525, -0.04000654],\n", - " [ 0.92353648, 0.63770926, -0.52810186, ..., -0.12927933, -0.20342292, 0.16655664],\n", - " [ 0.49337494, -0.00911332, -0.73301607, ..., 0.10074048, -0.09811471, -0.00923573],\n", - " ...,\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063],\n", - " [-6.04343224, -4.93973970, -3.42354989, ..., -3.99492049, -3.98687553, -3.67971063]]])\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "print(xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 205, - "id": "cutting-julian", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 256, 51, 19], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0.00209083],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0.01194306, 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0.04610471, 0. ],\n", - " [0. , 0. , 0. , ..., 0.00967231, 0.04613467, 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.22816099, 0.24614786, 0.25304127, ..., 0.20401822, 0.23248228, 0.31190544],\n", - " [0.13587360, 0.28877240, 0.27991283, ..., 0.19210319, 0.20346391, 0.19934426],\n", - " [0.25739068, 0.39348233, 0.27877361, ..., 0.27482539, 0.19302306, 0.23810163],\n", - " ...,\n", - " [0.11939213, 0.28473237, 0.33082074, ..., 0.23838061, 0.22104350, 0.23905794],\n", - " [0.17387670, 0.20402060, 0.40263173, ..., 0.24782266, 0.26742202, 0.15426503],\n", - " [0. , 0.29080707, 0.27725950, ..., 0.17539823, 0.18478745, 0.22483408]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.35446781, 0.38861471, 0.39724261, ..., 0.38680089, 0.33568040, 0.34552398],\n", - " [0.41739127, 0.51038563, 0.41729912, ..., 0.33992639, 0.37081629, 0.35109508],\n", - " [0.36116859, 0.40744874, 0.48490953, ..., 0.34848654, 0.32321057, 0.35188958],\n", - " ...,\n", - " [0.23143977, 0.38021481, 0.51526314, ..., 0.36499465, 0.37411752, 0.39986172],\n", - " [0.34678638, 0.40238205, 0.50076538, ..., 0.36184520, 0.31596646, 0.36334658],\n", - " [0.36498138, 0.37943166, 0.51718897, ..., 0.31798238, 0.33656698, 0.34130475]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.01456045, 0.09447514, 0. , ..., 0. , 0. , 0. ],\n", - " [0.01500242, 0.02963220, 0. , ..., 0. , 0. , 0. ],\n", - " [0.03295187, 0. , 0. , ..., 0.04584959, 0.02043908, 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0.04425837],\n", - " [0. , 0. , 0.02556529, ..., 0. , 0.00900441, 0.04908358]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.11141267, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.33696529, 0.38526866, 0.32900479, ..., 0.28703830, 0.23351061, 0.19004467],\n", - " [0.13575366, 0.35783342, 0.33573425, ..., 0.22081660, 0.15854910, 0.13587447],\n", - " [0.21928655, 0.28900093, 0.28255141, ..., 0.20602837, 0.23927397, 0.21909429],\n", - " ...,\n", - " [0.23291890, 0.39096734, 0.36399242, ..., 0.20598020, 0.25373828, 0.23137446],\n", - " [0.18739152, 0.30793777, 0.30296701, ..., 0.27250600, 0.25191751, 0.20836820],\n", - " [0.22454213, 0.41402060, 0.54082996, ..., 0.31874508, 0.25079906, 0.25938687]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.26456982, 0.49519050, 0.56702250, ..., 0.30954638, 0.35292268, 0.32668519],\n", - " [0.21576807, 0.51833367, 0.49183372, ..., 0.36043224, 0.38523889, 0.36154741],\n", - " [0.20067888, 0.42784205, 0.52817714, ..., 0.31871423, 0.32452232, 0.31036487],\n", - " ...,\n", - " [0.49855131, 0.51001430, 0.52278662, ..., 0.36450142, 0.34338164, 0.33602941],\n", - " [0.41233343, 0.55517823, 0.52827710, ..., 0.40675971, 0.33873138, 0.36724189],\n", - " [0.40820011, 0.46187383, 0.47338152, ..., 0.38690975, 0.36039269, 0.38022059]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0.00578516, 0. , ..., 0.00748384, 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0.03035110, 0. , 0.00026720],\n", - " [0.00094807, 0. , 0. , ..., 0.00795512, 0. , 0. ],\n", - " ...,\n", - " [0.02032628, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0.01080076, 0. ],\n", - " [0.18470290, 0. , 0. , ..., 0.05058352, 0.09475817, 0.05914564]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.38708323, 0.28021947, 0.35892880, ..., 0.16595127, 0.16031364, 0.21136315],\n", - " [0.15595171, 0.30544323, 0.24666184, ..., 0.22675267, 0.25765014, 0.19682154],\n", - " [0.29517862, 0.41209796, 0.20063159, ..., 0.17595036, 0.22536841, 0.22214051],\n", - " ...,\n", - " [0.24744980, 0.26258564, 0.38654143, ..., 0.23620218, 0.23157144, 0.18514194],\n", - " [0.25714791, 0.29592845, 0.47744542, ..., 0.23545510, 0.25072727, 0.20976165],\n", - " [1.20154655, 0.84644288, 0.73385584, ..., 1.02517247, 0.95309550, 1.00134516]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.45013186, 0.47484034, 0.40540054, ..., 0.19346163, 0.17825794, 0.14776605],\n", - " [0.47545874, 0.48186573, 0.36760187, ..., 0.27809089, 0.32997063, 0.32337096],\n", - " [0.46160024, 0.40050328, 0.39060861, ..., 0.36612910, 0.35242686, 0.29738861],\n", - " ...,\n", - " [0.55148494, 0.51017821, 0.40132499, ..., 0.38948193, 0.35737294, 0.33088297],\n", - " [0.41972569, 0.45475486, 0.45320493, ..., 0.38343129, 0.40125814, 0.36180776],\n", - " [0.34279808, 0.31606171, 0.44701228, ..., 0.21665487, 0.23984617, 0.23903391]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.04178291, 0. , 0.01580476, ..., 0. , 0.02250817, 0. ],\n", - " [0.04323414, 0.07786420, 0. , ..., 0.01634724, 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.03209178, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.13563479, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0. , 0.25187218, 0.24979387, ..., 0.24774717, 0.22354351, 0.19149347],\n", - " [0.16540922, 0.19585510, 0.19812922, ..., 0.27344131, 0.20928150, 0.26150429],\n", - " [0.10494646, 0.06329897, 0.33843631, ..., 0.25138417, 0.12470355, 0.23926635],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.11428106, 0.45667490, 0.46820879, ..., 0.32057840, 0.33578536, 0.39012644],\n", - " [0.10441341, 0.45739070, 0.46107352, ..., 0.38467997, 0.38291249, 0.36685589],\n", - " [0.19867736, 0.35519636, 0.44313061, ..., 0.40679252, 0.38067645, 0.30645671],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.02465414, 0. , 0. , ..., 0. , 0. , 0.03390232],\n", - " [0. , 0. , 0.01830704, ..., 0.05166877, 0.00948385, 0.07453502],\n", - " [0.09921519, 0. , 0.01587192, ..., 0.01620276, 0.05140074, 0.00192392],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.40034360, 0.25306445, 0.20217699, ..., 0.09816189, 0.07064310, 0.04974059],\n", - " [0.12567598, 0.21030979, 0.11181555, ..., 0.04278110, 0.11968569, 0.12005232],\n", - " [0.28786880, 0.24030517, 0.22565845, ..., 0. , 0.06418110, 0.05872961],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.38404641, 0.30990323, 0.37156230, ..., 0.18125033, 0.15050662, 0.19619957],\n", - " [0.47285745, 0.40528792, 0.39718056, ..., 0.24709940, 0.04565683, 0.11500744],\n", - " [0.32620737, 0.30072594, 0.30477354, ..., 0.23529193, 0.21356541, 0.16985542],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]],\n", - "\n", - "\n", - " [[[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.03343770, 0.00123780, 0.05297198, ..., 0.07271163, 0.08656286, 0.14493589],\n", - " [0.11043239, 0.06143146, 0.06362963, ..., 0.08127750, 0.06259022, 0.08315435],\n", - " [0.01767678, 0.00201111, 0.07875030, ..., 0.06963293, 0.08979890, 0.05326346],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.10033827, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.15627117, 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0.05144687, 0. , 0. , ..., 0. , 0. , 0.00436414],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " ...,\n", - "\n", - " [[0.25142455, 0.45964020, 0.37346074, ..., 0.04763087, 0. , 0. ],\n", - " [0.19760093, 0.26626948, 0.11190540, ..., 0.03044968, 0. , 0. ],\n", - " [0.16340607, 0.32938001, 0.25689697, ..., 0.05569421, 0. , 0. ],\n", - " ...,\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163],\n", - " [1.12572610, 0.87340784, 0.78169060, ..., 1.04576325, 1.00935984, 1.02209163]],\n", - "\n", - " [[0. , 0. , 0. , ..., 0. , 0.02218930, 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0.02848953],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", - "\n", - " [[0.25810039, 0.63016868, 0.37037861, ..., 0.18704373, 0.08269356, 0.09912672],\n", - " [0.17292863, 0.50678611, 0.40738991, ..., 0.16006103, 0.11725381, 0.09940521],\n", - " [0.24175072, 0.41616210, 0.41256818, ..., 0.13519743, 0.07912572, 0.12846369],\n", - " ...,\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700],\n", - " [1.44883108, 1.02119160, 0.94472742, ..., 1.23630035, 1.21888959, 1.23804700]]]])\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "\n", - "#xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "# print(xs)\n", - "\n", - "x = xs.unsqueeze(1)\n", - "x = model.encoder.embed.conv(x)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 206, - "id": "friendly-nightlife", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[-0.03426375, 0.14291267, -0.06718873, ..., 0.09064753, 0.01809387, -0.04340880],\n", - " [-0.05007839, 0.11054724, -0.10399298, ..., 0.11457238, 0.04244684, -0.01249714],\n", - " [-0.10695291, 0.16910909, -0.08352133, ..., 0.07710276, 0.01168563, -0.03584499],\n", - " ...,\n", - " [-0.06060536, 0.14455931, -0.05470302, ..., 0.05364908, 0.03033342, -0.02610814],\n", - " [-0.08505894, 0.13611752, -0.11132983, ..., 0.13079923, 0.01580139, -0.02281028],\n", - " [-0.10604677, 0.14714901, -0.10885533, ..., 0.08543444, 0.03719445, -0.04634233]],\n", - "\n", - " [[-0.12392755, 0.14486063, -0.05674079, ..., 0.02573164, 0.03128851, 0.00545091],\n", - " [-0.04775286, 0.08473608, -0.08507854, ..., 0.04573154, 0.04240163, 0.01053247],\n", - " [-0.05940291, 0.10023535, -0.08143730, ..., 0.03596500, 0.01673085, 0.02089563],\n", - " ...,\n", - " [-0.09222981, 0.15823206, -0.07700447, ..., 0.08122957, 0.03136991, -0.00646474],\n", - " [-0.07331756, 0.14482647, -0.07838815, ..., 0.10869440, 0.01356864, -0.02777974],\n", - " [-0.07937264, 0.20143102, -0.05544947, ..., 0.10287814, 0.00608235, -0.04799180]],\n", - "\n", - " [[-0.03670349, 0.08931590, -0.08718812, ..., 0.01314050, 0.00642052, 0.00573716],\n", - " [ 0.01089254, 0.11146393, -0.10263617, ..., 0.05070438, 0.01960694, 0.03521532],\n", - " [-0.02182280, 0.11443964, -0.06678198, ..., 0.04327708, 0.00861394, 0.02871092],\n", - " ...,\n", - " [-0.06792898, 0.14376275, -0.07899005, ..., 0.11248926, 0.03208683, -0.03264240],\n", - " [-0.07884051, 0.17024788, -0.08583611, ..., 0.09028331, 0.03588808, -0.02075090],\n", - " [-0.13792302, 0.27163863, -0.23930418, ..., 0.13391261, 0.07521040, -0.08621951]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.02446348, 0.11595841, -0.03591986, ..., 0.06288970, 0.02895011, -0.06532725],\n", - " [-0.05378424, 0.12607370, -0.09023033, ..., 0.09078894, 0.01035743, 0.03701983],\n", - " [-0.04566649, 0.14275314, -0.06686870, ..., 0.09890588, -0.00612222, 0.03439377],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", - "\n", - " [[-0.01012144, 0.03909408, -0.07077143, ..., 0.00452683, -0.01377654, 0.02897627],\n", - " [-0.00519154, 0.03594019, -0.06831125, ..., 0.05693541, -0.00406374, 0.04561640],\n", - " [-0.01762631, 0.00500899, -0.05886075, ..., 0.02112178, -0.00729015, 0.02782153],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]],\n", - "\n", - " [[-0.03411558, -0.04318277, -0.08497842, ..., -0.04886402, 0.04296734, 0.06151697],\n", - " [ 0.00263296, -0.06913657, -0.08993219, ..., -0.00149064, 0.05696633, 0.03304394],\n", - " [-0.01818341, -0.01178640, -0.09679577, ..., -0.00870231, 0.00362198, 0.01916483],\n", - " ...,\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698],\n", - " [-0.31763062, 0.53700209, -0.26335421, ..., 0.39182857, 0.00337184, -0.18293698]]])\n", - "Tensor(shape=[16, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[[-0.54821998, 2.28660274, -1.07501972, ..., 1.45036042, 0.28950194, -0.69454080],\n", - " [-0.80125421, 1.76875579, -1.66388774, ..., 1.83315802, 0.67914939, -0.19995420],\n", - " [-1.71124649, 2.70574546, -1.33634126, ..., 1.23364413, 0.18697014, -0.57351983],\n", - " ...,\n", - " [-0.96968573, 2.31294894, -0.87524825, ..., 0.85838526, 0.48533469, -0.41773027],\n", - " [-1.36094308, 2.17788029, -1.78127730, ..., 2.09278774, 0.25282228, -0.36496443],\n", - " [-1.69674826, 2.35438418, -1.74168527, ..., 1.36695099, 0.59511113, -0.74147725]],\n", - "\n", - " [[-1.98284078, 2.31777000, -0.90785271, ..., 0.41170627, 0.50061619, 0.08721463],\n", - " [-0.76404583, 1.35577726, -1.36125672, ..., 0.73170459, 0.67842603, 0.16851945],\n", - " [-0.95044655, 1.60376561, -1.30299675, ..., 0.57544005, 0.26769355, 0.33433008],\n", - " ...,\n", - " [-1.47567701, 2.53171301, -1.23207152, ..., 1.29967308, 0.50191855, -0.10343577],\n", - " [-1.17308092, 2.31722355, -1.25421047, ..., 1.73911047, 0.21709818, -0.44447583],\n", - " [-1.26996231, 3.22289634, -0.88719147, ..., 1.64605021, 0.09731755, -0.76786882]],\n", - "\n", - " [[-0.58725590, 1.42905438, -1.39500988, ..., 0.21024795, 0.10272825, 0.09179455],\n", - " [ 0.17428070, 1.78342295, -1.64217877, ..., 0.81127012, 0.31371105, 0.56344515],\n", - " [-0.34916472, 1.83103430, -1.06851172, ..., 0.69243336, 0.13782299, 0.45937473],\n", - " ...,\n", - " [-1.08686376, 2.30020404, -1.26384079, ..., 1.79982817, 0.51338923, -0.52227837],\n", - " [-1.26144814, 2.72396612, -1.37337780, ..., 1.44453299, 0.57420933, -0.33201432],\n", - " [-2.20676827, 4.34621811, -3.82886696, ..., 2.14260173, 1.20336640, -1.37951219]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.39141566, 1.85533464, -0.57471782, ..., 1.00623512, 0.46320182, -1.04523599],\n", - " [-0.86054784, 2.01717925, -1.44368529, ..., 1.45262301, 0.16571884, 0.59231722],\n", - " [-0.73066384, 2.28405023, -1.06989920, ..., 1.58249414, -0.09795550, 0.55030036],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", - "\n", - " [[-0.16194311, 0.62550521, -1.13234293, ..., 0.07242929, -0.22042468, 0.46362036],\n", - " [-0.08306468, 0.57504302, -1.09298003, ..., 0.91096652, -0.06501988, 0.72986233],\n", - " [-0.28202093, 0.08014385, -0.94177192, ..., 0.33794850, -0.11664233, 0.44514441],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]],\n", - "\n", - " [[-0.54584920, -0.69092435, -1.35965478, ..., -0.78182435, 0.68747747, 0.98427159],\n", - " [ 0.04212743, -1.10618520, -1.43891501, ..., -0.02385022, 0.91146135, 0.52870303],\n", - " [-0.29093450, -0.18858244, -1.54873240, ..., -0.13923697, 0.05795169, 0.30663735],\n", - " ...,\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170],\n", - " [-5.08208990, 8.59203339, -4.21366739, ..., 6.26925707, 0.05394945, -2.92699170]]])\n", - "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", - " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", - " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", - " ...,\n", - " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", - " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", - " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" - ] - } - ], - "source": [ - "b, c, t, f = paddle.shape(x)\n", - "x = model.encoder.embed.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))\n", - "print(x)\n", - "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", - "print(x)\n", - "print(pos_emb)" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "id": "guilty-cache", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tensor(shape=[1, 51, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n", - " [[[ 0. , 1. , 0. , ..., 1. , 0. , 1. ],\n", - " [ 0.84147102, 0.54030228, 0.80196184, ..., 1. , 0.00010746, 1. ],\n", - " [ 0.90929747, -0.41614681, 0.95814437, ..., 1. , 0.00021492, 1. ],\n", - " ...,\n", - " [-0.76825470, -0.64014435, 0.63279730, ..., 0.99998462, 0.00515809, 0.99998671],\n", - " [-0.95375264, 0.30059254, 0.99899054, ..., 0.99998397, 0.00526555, 0.99998611],\n", - " [-0.26237485, 0.96496606, 0.56074661, ..., 0.99998331, 0.00537301, 0.99998558]]])\n" - ] - } - ], - "source": [ - "print(pos_emb)" - ] - }, - { - "cell_type": "code", - "execution_count": 208, - "id": "iraqi-payday", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", - " 0.0000000e+00 1.0000000e+00]\n", - " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", - " 1.0746076e-04 1.0000000e+00]\n", - " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", - " 2.1492151e-04 1.0000000e+00]\n", - " ...\n", - " [ 9.5625257e-01 -2.9254240e-01 4.8925215e-01 ... 8.3807874e-01\n", - " 5.1154459e-01 8.5925674e-01]\n", - " [ 2.7049953e-01 -9.6272010e-01 9.9170387e-01 ... 8.3801574e-01\n", - " 5.1163691e-01 8.5920173e-01]\n", - " [-6.6394955e-01 -7.4777740e-01 6.9544029e-01 ... 8.3795273e-01\n", - " 5.1172924e-01 8.5914677e-01]]]\n", - "[1, 5000, 256]\n" - ] - } - ], - "source": [ - "import torch\n", - "import math\n", - "import numpy as np\n", - "\n", - "max_len=5000\n", - "d_model=256\n", - "\n", - "pe = torch.zeros(max_len, d_model)\n", - "position = torch.arange(0, max_len,\n", - " dtype=torch.float32).unsqueeze(1)\n", - "toruch_position = position\n", - "div_term = torch.exp(\n", - " torch.arange(0, d_model, 2, dtype=torch.float32) *\n", - " -(math.log(10000.0) / d_model))\n", - "tourch_div_term = div_term.cpu().detach().numpy()\n", - "\n", - "torhc_sin = torch.sin(position * div_term)\n", - "torhc_cos = torch.cos(position * div_term)\n", - "\n", - "np_sin = np.sin((position * div_term).cpu().detach().numpy())\n", - "np_cos = np.cos((position * div_term).cpu().detach().numpy())\n", - "pe[:, 0::2] = torhc_sin\n", - "pe[:, 1::2] = torhc_cos\n", - "pe = pe.unsqueeze(0) \n", - "tourch_pe = pe.cpu().detach().numpy()\n", - "print(tourch_pe)\n", - "bak_pe = model.encoder.embed.pos_enc.pe\n", - "print(bak_pe.shape)\n", - "model.encoder.embed.pos_enc.pe = paddle.to_tensor(tourch_pe)" - ] - }, - { - "cell_type": "code", - "execution_count": 210, - "id": "exempt-cloud", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "#print(xs)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "composite-involvement", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 269, - "id": "handed-harris", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "True\n", - "False\n", - "True\n", - "[256, 2048]\n", - "[2048]\n", - "[2048, 256]\n", - "[256]\n", - "--------ff-------\n", - "True\n", - "False\n", - "False\n", - "False\n", - "False\n", - "True\n", - "linear_714.w_0 True\n", - "linear_714.b_0 True\n", - "linear_715.w_0 True\n", - "linear_715.b_0 True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "for layer in model.encoder.encoders:\n", - " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " print(layer.feed_forward_macaron is not None)\n", - " print(layer.normalize_before)\n", - " \n", - " data = np.load('.notebook/enc_0_norm_ff.npz')\n", - " t_norm_ff = data['norm_ff']\n", - " t_xs = data['xs']\n", - " \n", - " \n", - " x = xs\n", - " print(np.allclose(t_xs, x.numpy()))\n", - " residual = x\n", - " print(np.allclose(t_xs, residual.numpy()))\n", - " x_nrom = layer.norm_ff_macaron(x)\n", - " print(np.allclose(t.numpy(), x_nrom.numpy()))\n", - " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", - "# for n, p in layer.norm_ff_macaron.state_dict().items():\n", - "# print(n, p)\n", - "# pass\n", - "\n", - " layer.eval()\n", - " x_nrom = paddle.to_tensor(t_norm_ff)\n", - " print(np.allclose(t_norm_ff, x_nrom.numpy()))\n", - " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_nrom)\n", - " \n", - " ps=[]\n", - " for n, p in layer.feed_forward_macaron.state_dict().items():\n", - " #print(n, p)\n", - " ps.append(p)\n", - " print(p.shape)\n", - " pass\n", - "\n", - " x_nrom = paddle.to_tensor(t_norm_ff)\n", - " ff_l_x = layer.feed_forward_macaron.w_1(x_nrom)\n", - " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", - " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", - " data = np.load('.notebook/enc_0_ff_out.npz', allow_pickle=True)\n", - " t_norm_ff = data['norm_ff']\n", - " t_ff_out = data['ff_out']\n", - " t_ff_l_x = data['ff_l_x']\n", - " t_ff_l_a_x = data['ff_l_a_x']\n", - " t_ff_l_a_l_x = data['ff_l_a_l_x']\n", - " t_ps = data['ps']\n", - " \n", - " print(\"--------ff-------\")\n", - " print(np.allclose(x_nrom.numpy(), t_norm_ff))\n", - " print(np.allclose(x.numpy(), t_ff_out))\n", - " print(np.allclose(ff_l_x.numpy(), t_ff_l_x))\n", - " print(np.allclose(ff_l_a_x.numpy(), t_ff_l_a_x))\n", - " print(np.allclose(ff_l_a_l_x.numpy(), t_ff_l_a_l_x))\n", - " \n", - " print(np.allclose(ff_l_x.numpy(), t_ff_l_x, atol=1e-6))\n", - " for p, t_p in zip(ps, t_ps):\n", - " print(p.name, np.allclose(p.numpy(), t_p.T))\n", - " \n", - " \n", - "# residual = x\n", - "# x = layer.norm_mha(x)\n", - "# x_q = x\n", - " \n", - " data = np.load('.notebook/enc_0_selattn_out.npz', allow_pickle=True)\n", - " tx_q = data['x_q']\n", - " tx = data['x']\n", - " tpos_emb=data['pos_emb']\n", - " tmask=data['mask']\n", - " tt_x_att=data['x_att']\n", - " x_q = paddle.to_tensor(tx_q)\n", - " x = paddle.to_tensor(tx)\n", - " pos_emb = paddle.to_tensor(tpos_emb)\n", - " mask = paddle.to_tensor(tmask)\n", - " \n", - " x_att = layer.self_attn(x_q, x, x, pos_emb, mask)\n", - " print(np.allclose(x_att.numpy(), t_x_att))\n", - " print(np.allclose(x_att.numpy(), t_x_att, atol=1e-6))\n", - " \n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 270, - "id": "sonic-thumb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "False\n", - "True\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " break\n", - "data = np.load('.notebook/enc_0.npz')\n", - "torch_xs = data['enc_0']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(xs.numpy(), torch_xs, atol=1e-6))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 273, - "id": "brave-latino", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "--------layers_______\n", - "False\n", - "True\n", - "[[-0.70194244 0.56254214 0.6880346 ... 1.1237319 0.7803924\n", - " 1.1369387 ]\n", - " [-0.7787783 0.3912667 0.71887773 ... 1.251882 0.886168\n", - " 1.3173451 ]\n", - " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", - " 1.2903278 ]\n", - " ...\n", - " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", - " 1.3239655 ]\n", - " [-1.1654118 0.6819967 0.6939453 ... 1.2238353 0.8028295\n", - " 1.4506507 ]\n", - " [-1.2732092 0.7145806 0.75819594 ... 0.94154835 0.8774845\n", - " 1.2623049 ]]\n", - "xxxxxx\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.273209 0.71458095 0.75819623 ... 0.9415484 0.8774842\n", - " 1.2623055 ]]\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = make_non_pad_mask(feat_len).unsqueeze(1)\n", - "\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks.type_as(xs), offset=0)\n", - "masks = masks.astype(paddle.bool)\n", - "mask_pad = masks.logical_not()\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, masks, model.encoder.use_dynamic_chunk, model.encoder.use_dynamic_left_chunk,\n", - " decoding_chunk_size, model.encoder.static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "#print(chunk_masks)\n", - "data = np.load(\".notebook/enc_embed.npz\")\n", - "torch_pos_emb=data['pos_emb']\n", - "torch_xs = data['embed_out']\n", - "torch_chunk_masks = data['chunk_masks']\n", - "torch_mask_pad = data['mask_pad']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(pos_emb.numpy(), torch_pos_emb))\n", - "np.testing.assert_equal(chunk_masks.numpy(), torch_chunk_masks)\n", - "np.testing.assert_equal(mask_pad.numpy(), ~torch_mask_pad)\n", - "\n", - "print(\"--------layers_______\")\n", - "i =0\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " i+=1\n", - "# if i == 2:\n", - "# data = np.load('.notebook/enc_2.npz')\n", - "# torch_xs = data['enc_2']\n", - "# print(np.allclose(xs.numpy(), torch_xs))\n", - "# print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", - "# print(xs[0].numpy())\n", - "# print('xxxxxx')\n", - "# print(torch_xs[0])\n", - "# print('----i==2')\n", - "data = np.load('.notebook/enc_all.npz')\n", - "torch_xs = data['enc_all']\n", - "print(np.allclose(xs.numpy(), torch_xs))\n", - "print(np.allclose(xs.numpy(), torch_xs, atol=1e-5))\n", - "print(xs[0].numpy())\n", - "print('xxxxxx')\n", - "print(torch_xs[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "municipal-stock", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 278, - "id": "macro-season", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[-0.7019424 0.5625421 0.68803453 ... 1.1237317 0.7803923\n", - " 1.1369386 ]\n", - " [-0.7787783 0.39126673 0.71887773 ... 1.251882 0.886168\n", - " 1.3173451 ]\n", - " [-0.95908964 0.6346029 0.87671334 ... 0.98183745 0.7440111\n", - " 1.2903278 ]\n", - " ...\n", - " [-1.0732255 0.67236906 0.92303115 ... 0.9075458 0.8176712\n", - " 1.3239655 ]\n", - " [-1.1654117 0.68199664 0.6939452 ... 1.2238352 0.8028294\n", - " 1.4506506 ]\n", - " [-1.2732091 0.71458054 0.7581958 ... 0.9415482 0.8774844\n", - " 1.2623048 ]]\n", - "---\n", - "[[-0.7019424 0.56254166 0.6880345 ... 1.1237322 0.78039217\n", - " 1.1369387 ]\n", - " [-0.778778 0.39126638 0.7188779 ... 1.2518823 0.8861681\n", - " 1.3173454 ]\n", - " [-0.9590891 0.6346026 0.87671363 ... 0.9818373 0.74401116\n", - " 1.2903274 ]\n", - " ...\n", - " [-1.0732253 0.6723689 0.9230311 ... 0.9075457 0.8176713\n", - " 1.3239657 ]\n", - " [-1.165412 0.6819976 0.69394535 ... 1.2238353 0.80282927\n", - " 1.4506509 ]\n", - " [-1.2732087 0.71458083 0.7581961 ... 0.9415482 0.877484\n", - " 1.2623053 ]]\n", - "False\n", - "True\n", - "False\n" - ] - } - ], - "source": [ - "encoder_out, mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.numpy()[0])\n", - "print(\"---\")\n", - "print(torch_encoder_out[0])\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy()))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-5))\n", - "print(np.allclose(torch_encoder_out, encoder_out.numpy(), atol=1e-6))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "associate-sampling", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/u2_tansformer_model_espnet.ipynb b/.notebook/u2_tansformer_model_espnet.ipynb deleted file mode 100644 index 75c2ea5c6..000000000 --- a/.notebook/u2_tansformer_model_espnet.ipynb +++ /dev/null @@ -1,1672 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "choice-grade", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/DeepSpeech-2.x'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd ..\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "broke-broad", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", - "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", - " def convert_to_list(value, n, name, dtype=np.int):\n", - "register user softmax to paddle, remove this when fixed!\n", - "register user log_softmax to paddle, remove this when fixed!\n", - "register user sigmoid to paddle, remove this when fixed!\n", - "register user log_sigmoid to paddle, remove this when fixed!\n", - "register user relu to paddle, remove this when fixed!\n", - "override cat of paddle if exists or register, remove this when fixed!\n", - "override item of paddle.Tensor if exists or register, remove this when fixed!\n", - "override long of paddle.Tensor if exists or register, remove this when fixed!\n", - "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", - "override eq of paddle if exists or register, remove this when fixed!\n", - "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", - "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", - "register user view to paddle.Tensor, remove this when fixed!\n", - "register user view_as to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill to paddle.Tensor, remove this when fixed!\n", - "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", - "register user fill_ to paddle.Tensor, remove this when fixed!\n", - "register user repeat to paddle.Tensor, remove this when fixed!\n", - "register user softmax to paddle.Tensor, remove this when fixed!\n", - "register user sigmoid to paddle.Tensor, remove this when fixed!\n", - "register user relu to paddle.Tensor, remove this when fixed!\n", - "register user type_as to paddle.Tensor, remove this when fixed!\n", - "register user to to paddle.Tensor, remove this when fixed!\n", - "register user float to paddle.Tensor, remove this when fixed!\n", - "register user tolist to paddle.Tensor, remove this when fixed!\n", - "register user glu to paddle.nn.functional, remove this when fixed!\n", - "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", - "register user Module to paddle.nn, remove this when fixed!\n", - "register user ModuleList to paddle.nn, remove this when fixed!\n", - "register user GLU to paddle.nn, remove this when fixed!\n", - "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", - "register user export to paddle.jit, remove this when fixed!\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import paddle\n", - "from yacs.config import CfgNode as CN\n", - "\n", - "from deepspeech.models.u2 import U2Model\n", - "from deepspeech.utils.layer_tools import print_params\n", - "from deepspeech.utils.layer_tools import summary" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "permanent-summary", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n", - "[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n", - "[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n", - "attention_heads: 4\n", - "dropout_rate: 0.1\n", - "input_layer: conv2d\n", - "linear_units: 2048\n", - "normalize_before: True\n", - "num_blocks: 12\n", - "output_size: 256\n", - "positional_dropout_rate: 0.1\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "encoder.embed.conv.0.bias | [256] | 256 | True\n", - "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "encoder.embed.conv.2.bias | [256] | 256 | True\n", - "encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n", - "encoder.embed.out.0.bias | [256] | 256 | True\n", - "encoder.after_norm.weight | [256] | 256 | True\n", - "encoder.after_norm.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.0.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.0.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.1.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.1.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.2.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.2.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.3.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.3.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.4.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.4.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.5.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.5.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.6.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.6.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.7.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.7.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.8.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.8.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.9.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.9.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.10.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.10.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm1.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm1.bias | [256] | 256 | True\n", - "encoder.encoders.11.norm2.weight | [256] | 256 | True\n", - "encoder.encoders.11.norm2.bias | [256] | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", - "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", - "decoder.after_norm.weight | [256] | 256 | True\n", - "decoder.after_norm.bias | [256] | 256 | True\n", - "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", - "decoder.output_layer.bias | [4233] | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", - "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", - "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", - "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", - "ctc.ctc_lo.bias | [4233] | 4233 | True\n", - "Total parameters: 411.0, 32.01M elements.\n" - ] - } - ], - "source": [ - "conf_str='examples/tiny/s1/conf/transformer.yaml'\n", - "cfg = CN().load_cfg(open(conf_str))\n", - "cfg.model.input_dim = 83\n", - "cfg.model.output_dim = 4233\n", - "cfg.model.cmvn_file = None\n", - "cfg.model.cmvn_file_type = 'json'\n", - "#cfg.model.encoder_conf.concat_after=True\n", - "cfg.freeze()\n", - "model = U2Model(cfg.model)\n", - "\n", - "print_params(model)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "sapphire-agent", - "metadata": {}, - "outputs": [], - "source": [ - "#summary(model)\n", - "#print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ruled-invitation", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "fossil-means", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", - "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook/espnet" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "45c2b75f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state\n", - "odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n" - ] - } - ], - "source": [ - "#!pip install torch\n", - "import torch\n", - "\n", - "e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n", - "for k in e_model.files:\n", - " print(k)\n", - "state_dict = e_model['state']\n", - "state_dict = state_dict.tolist()\n", - "print(state_dict.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f187bb55", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", - " and should_run_async(code)\n" - ] - } - ], - "source": [ - "# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", - "# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n", - "# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", - "# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n", - "# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n", - "# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n", - "# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n", - "# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", - "# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", - "# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", - "# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", - "# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", - "# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", - "# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", - "# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n", - "# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n", - "# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n", - "# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n", - "# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", - "# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", - "# espnet transformer\tconcat_linear只是保存了,但是未使用\n", - "\t\n", - "# \tpaddle transformer" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2a0428ae", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-> encoder.embed.conv.0.weight\n", - "-> encoder.embed.conv.0.bias\n", - "-> encoder.embed.conv.2.weight\n", - "-> encoder.embed.conv.2.bias\n", - "-> encoder.embed.out.0.weight\n", - "encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n", - "-> encoder.embed.out.0.bias\n", - "-> encoder.encoders.0.self_attn.linear_q.weight\n", - "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_q.bias\n", - "-> encoder.encoders.0.self_attn.linear_k.weight\n", - "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_k.bias\n", - "-> encoder.encoders.0.self_attn.linear_v.weight\n", - "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_v.bias\n", - "-> encoder.encoders.0.self_attn.linear_out.weight\n", - "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.0.self_attn.linear_out.bias\n", - "-> encoder.encoders.0.feed_forward.w_1.weight\n", - "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.0.feed_forward.w_1.bias\n", - "-> encoder.encoders.0.feed_forward.w_2.weight\n", - "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.0.feed_forward.w_2.bias\n", - "-> encoder.encoders.0.norm1.weight\n", - "-> encoder.encoders.0.norm1.bias\n", - "-> encoder.encoders.0.norm2.weight\n", - "-> encoder.encoders.0.norm2.bias\n", - "-> encoder.encoders.1.self_attn.linear_q.weight\n", - "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_q.bias\n", - "-> encoder.encoders.1.self_attn.linear_k.weight\n", - "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_k.bias\n", - "-> encoder.encoders.1.self_attn.linear_v.weight\n", - "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_v.bias\n", - "-> encoder.encoders.1.self_attn.linear_out.weight\n", - "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.1.self_attn.linear_out.bias\n", - "-> encoder.encoders.1.feed_forward.w_1.weight\n", - "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.1.feed_forward.w_1.bias\n", - "-> encoder.encoders.1.feed_forward.w_2.weight\n", - "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.1.feed_forward.w_2.bias\n", - "-> encoder.encoders.1.norm1.weight\n", - "-> encoder.encoders.1.norm1.bias\n", - "-> encoder.encoders.1.norm2.weight\n", - "-> encoder.encoders.1.norm2.bias\n", - "-> encoder.encoders.2.self_attn.linear_q.weight\n", - "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_q.bias\n", - "-> encoder.encoders.2.self_attn.linear_k.weight\n", - "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_k.bias\n", - "-> encoder.encoders.2.self_attn.linear_v.weight\n", - "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_v.bias\n", - "-> encoder.encoders.2.self_attn.linear_out.weight\n", - "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.2.self_attn.linear_out.bias\n", - "-> encoder.encoders.2.feed_forward.w_1.weight\n", - "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.2.feed_forward.w_1.bias\n", - "-> encoder.encoders.2.feed_forward.w_2.weight\n", - "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.2.feed_forward.w_2.bias\n", - "-> encoder.encoders.2.norm1.weight\n", - "-> encoder.encoders.2.norm1.bias\n", - "-> encoder.encoders.2.norm2.weight\n", - "-> encoder.encoders.2.norm2.bias\n", - "-> encoder.encoders.3.self_attn.linear_q.weight\n", - "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_q.bias\n", - "-> encoder.encoders.3.self_attn.linear_k.weight\n", - "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_k.bias\n", - "-> encoder.encoders.3.self_attn.linear_v.weight\n", - "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_v.bias\n", - "-> encoder.encoders.3.self_attn.linear_out.weight\n", - "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.3.self_attn.linear_out.bias\n", - "-> encoder.encoders.3.feed_forward.w_1.weight\n", - "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.3.feed_forward.w_1.bias\n", - "-> encoder.encoders.3.feed_forward.w_2.weight\n", - "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.3.feed_forward.w_2.bias\n", - "-> encoder.encoders.3.norm1.weight\n", - "-> encoder.encoders.3.norm1.bias\n", - "-> encoder.encoders.3.norm2.weight\n", - "-> encoder.encoders.3.norm2.bias\n", - "-> encoder.encoders.4.self_attn.linear_q.weight\n", - "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_q.bias\n", - "-> encoder.encoders.4.self_attn.linear_k.weight\n", - "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_k.bias\n", - "-> encoder.encoders.4.self_attn.linear_v.weight\n", - "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_v.bias\n", - "-> encoder.encoders.4.self_attn.linear_out.weight\n", - "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.4.self_attn.linear_out.bias\n", - "-> encoder.encoders.4.feed_forward.w_1.weight\n", - "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.4.feed_forward.w_1.bias\n", - "-> encoder.encoders.4.feed_forward.w_2.weight\n", - "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.4.feed_forward.w_2.bias\n", - "-> encoder.encoders.4.norm1.weight\n", - "-> encoder.encoders.4.norm1.bias\n", - "-> encoder.encoders.4.norm2.weight\n", - "-> encoder.encoders.4.norm2.bias\n", - "-> encoder.encoders.5.self_attn.linear_q.weight\n", - "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_q.bias\n", - "-> encoder.encoders.5.self_attn.linear_k.weight\n", - "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_k.bias\n", - "-> encoder.encoders.5.self_attn.linear_v.weight\n", - "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_v.bias\n", - "-> encoder.encoders.5.self_attn.linear_out.weight\n", - "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.5.self_attn.linear_out.bias\n", - "-> encoder.encoders.5.feed_forward.w_1.weight\n", - "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.5.feed_forward.w_1.bias\n", - "-> encoder.encoders.5.feed_forward.w_2.weight\n", - "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.5.feed_forward.w_2.bias\n", - "-> encoder.encoders.5.norm1.weight\n", - "-> encoder.encoders.5.norm1.bias\n", - "-> encoder.encoders.5.norm2.weight\n", - "-> encoder.encoders.5.norm2.bias\n", - "-> encoder.encoders.6.self_attn.linear_q.weight\n", - "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_q.bias\n", - "-> encoder.encoders.6.self_attn.linear_k.weight\n", - "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_k.bias\n", - "-> encoder.encoders.6.self_attn.linear_v.weight\n", - "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_v.bias\n", - "-> encoder.encoders.6.self_attn.linear_out.weight\n", - "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.6.self_attn.linear_out.bias\n", - "-> encoder.encoders.6.feed_forward.w_1.weight\n", - "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.6.feed_forward.w_1.bias\n", - "-> encoder.encoders.6.feed_forward.w_2.weight\n", - "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.6.feed_forward.w_2.bias\n", - "-> encoder.encoders.6.norm1.weight\n", - "-> encoder.encoders.6.norm1.bias\n", - "-> encoder.encoders.6.norm2.weight\n", - "-> encoder.encoders.6.norm2.bias\n", - "-> encoder.encoders.7.self_attn.linear_q.weight\n", - "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_q.bias\n", - "-> encoder.encoders.7.self_attn.linear_k.weight\n", - "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_k.bias\n", - "-> encoder.encoders.7.self_attn.linear_v.weight\n", - "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_v.bias\n", - "-> encoder.encoders.7.self_attn.linear_out.weight\n", - "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.7.self_attn.linear_out.bias\n", - "-> encoder.encoders.7.feed_forward.w_1.weight\n", - "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.7.feed_forward.w_1.bias\n", - "-> encoder.encoders.7.feed_forward.w_2.weight\n", - "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.7.feed_forward.w_2.bias\n", - "-> encoder.encoders.7.norm1.weight\n", - "-> encoder.encoders.7.norm1.bias\n", - "-> encoder.encoders.7.norm2.weight\n", - "-> encoder.encoders.7.norm2.bias\n", - "-> encoder.encoders.8.self_attn.linear_q.weight\n", - "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_q.bias\n", - "-> encoder.encoders.8.self_attn.linear_k.weight\n", - "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_k.bias\n", - "-> encoder.encoders.8.self_attn.linear_v.weight\n", - "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_v.bias\n", - "-> encoder.encoders.8.self_attn.linear_out.weight\n", - "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.8.self_attn.linear_out.bias\n", - "-> encoder.encoders.8.feed_forward.w_1.weight\n", - "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.8.feed_forward.w_1.bias\n", - "-> encoder.encoders.8.feed_forward.w_2.weight\n", - "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.8.feed_forward.w_2.bias\n", - "-> encoder.encoders.8.norm1.weight\n", - "-> encoder.encoders.8.norm1.bias\n", - "-> encoder.encoders.8.norm2.weight\n", - "-> encoder.encoders.8.norm2.bias\n", - "-> encoder.encoders.9.self_attn.linear_q.weight\n", - "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_q.bias\n", - "-> encoder.encoders.9.self_attn.linear_k.weight\n", - "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_k.bias\n", - "-> encoder.encoders.9.self_attn.linear_v.weight\n", - "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_v.bias\n", - "-> encoder.encoders.9.self_attn.linear_out.weight\n", - "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.9.self_attn.linear_out.bias\n", - "-> encoder.encoders.9.feed_forward.w_1.weight\n", - "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.9.feed_forward.w_1.bias\n", - "-> encoder.encoders.9.feed_forward.w_2.weight\n", - "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.9.feed_forward.w_2.bias\n", - "-> encoder.encoders.9.norm1.weight\n", - "-> encoder.encoders.9.norm1.bias\n", - "-> encoder.encoders.9.norm2.weight\n", - "-> encoder.encoders.9.norm2.bias\n", - "-> encoder.encoders.10.self_attn.linear_q.weight\n", - "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_q.bias\n", - "-> encoder.encoders.10.self_attn.linear_k.weight\n", - "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_k.bias\n", - "-> encoder.encoders.10.self_attn.linear_v.weight\n", - "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_v.bias\n", - "-> encoder.encoders.10.self_attn.linear_out.weight\n", - "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.10.self_attn.linear_out.bias\n", - "-> encoder.encoders.10.feed_forward.w_1.weight\n", - "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.10.feed_forward.w_1.bias\n", - "-> encoder.encoders.10.feed_forward.w_2.weight\n", - "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.10.feed_forward.w_2.bias\n", - "-> encoder.encoders.10.norm1.weight\n", - "-> encoder.encoders.10.norm1.bias\n", - "-> encoder.encoders.10.norm2.weight\n", - "-> encoder.encoders.10.norm2.bias\n", - "-> encoder.encoders.11.self_attn.linear_q.weight\n", - "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_q.bias\n", - "-> encoder.encoders.11.self_attn.linear_k.weight\n", - "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_k.bias\n", - "-> encoder.encoders.11.self_attn.linear_v.weight\n", - "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_v.bias\n", - "-> encoder.encoders.11.self_attn.linear_out.weight\n", - "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "-> encoder.encoders.11.self_attn.linear_out.bias\n", - "-> encoder.encoders.11.feed_forward.w_1.weight\n", - "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "-> encoder.encoders.11.feed_forward.w_1.bias\n", - "-> encoder.encoders.11.feed_forward.w_2.weight\n", - "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "-> encoder.encoders.11.feed_forward.w_2.bias\n", - "-> encoder.encoders.11.norm1.weight\n", - "-> encoder.encoders.11.norm1.bias\n", - "-> encoder.encoders.11.norm2.weight\n", - "-> encoder.encoders.11.norm2.bias\n", - "-> encoder.after_norm.weight\n", - "-> encoder.after_norm.bias\n" - ] - } - ], - "source": [ - "# dump torch model to paddle\n", - "#state_dict = model.state_dict()\n", - "paddle_state_dict = {}\n", - "\n", - "for n, p in state_dict.items():\n", - " if 'encoder' not in n:\n", - " continue \n", - " print(f'-> {n}')\n", - " \n", - " \n", - " name_change=True\n", - " if 'norm.running_mean' in n:\n", - " new_n = n.replace('norm.running_', 'norm._')\n", - " elif 'norm.running_var' in n:\n", - " new_n = n.replace('norm.running_var', 'norm._variance')\n", - " else:\n", - " name_change=False\n", - " new_n = n\n", - " if name_change:\n", - " print(f\"{n} -> {new_n}\")\n", - " \n", - " \n", - " p = p.cpu().detach().numpy()\n", - " if n.endswith('weight') and p.ndim == 2:\n", - " new_p = p.T\n", - " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", - " else:\n", - " new_p = p\n", - " \n", - " if 'global_cmvn.mean' in n:\n", - " print(p, p.dtype)\n", - " \n", - " paddle_state_dict[new_n] = new_p\n", - " \n", - "# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", - "# state=paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a1d97e9f", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.bias. encoder.encoders.6.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.weight. encoder.encoders.7.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.7.concat_linear.bias. encoder.encoders.7.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.weight. encoder.encoders.8.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.8.concat_linear.bias. encoder.encoders.8.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.weight. encoder.encoders.9.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.9.concat_linear.bias. encoder.encoders.9.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.weight. encoder.encoders.10.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.10.concat_linear.bias. encoder.encoders.10.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.weight. encoder.encoders.11.concat_linear.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.11.concat_linear.bias. encoder.encoders.11.concat_linear.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.embed.0.weight. decoder.embed.0.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.weight. decoder.after_norm.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.after_norm.bias. decoder.after_norm.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.weight. decoder.output_layer.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.output_layer.bias. decoder.output_layer.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.weight. decoder.decoders.0.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_q.bias. decoder.decoders.0.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.weight. decoder.decoders.0.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_k.bias. decoder.decoders.0.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.weight. decoder.decoders.0.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_v.bias. decoder.decoders.0.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.weight. decoder.decoders.0.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.self_attn.linear_out.bias. decoder.decoders.0.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.weight. decoder.decoders.0.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_q.bias. decoder.decoders.0.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.weight. decoder.decoders.0.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_k.bias. decoder.decoders.0.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.weight. decoder.decoders.0.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_v.bias. decoder.decoders.0.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.weight. decoder.decoders.0.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.src_attn.linear_out.bias. decoder.decoders.0.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.weight. decoder.decoders.0.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_1.bias. decoder.decoders.0.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.weight. decoder.decoders.0.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.feed_forward.w_2.bias. decoder.decoders.0.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.weight. decoder.decoders.0.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm1.bias. decoder.decoders.0.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.weight. decoder.decoders.0.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm2.bias. decoder.decoders.0.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.weight. decoder.decoders.0.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.norm3.bias. decoder.decoders.0.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.weight. decoder.decoders.0.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear1.bias. decoder.decoders.0.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.weight. decoder.decoders.0.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.0.concat_linear2.bias. decoder.decoders.0.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.weight. decoder.decoders.1.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_q.bias. decoder.decoders.1.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.weight. decoder.decoders.1.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_k.bias. decoder.decoders.1.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.weight. decoder.decoders.1.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_v.bias. decoder.decoders.1.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.weight. decoder.decoders.1.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.self_attn.linear_out.bias. decoder.decoders.1.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.weight. decoder.decoders.1.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_q.bias. decoder.decoders.1.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.weight. decoder.decoders.1.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_k.bias. decoder.decoders.1.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.weight. decoder.decoders.1.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_v.bias. decoder.decoders.1.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.weight. decoder.decoders.1.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.src_attn.linear_out.bias. decoder.decoders.1.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.weight. decoder.decoders.1.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_1.bias. decoder.decoders.1.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.weight. decoder.decoders.1.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.feed_forward.w_2.bias. decoder.decoders.1.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.weight. decoder.decoders.1.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm1.bias. decoder.decoders.1.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.weight. decoder.decoders.1.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm2.bias. decoder.decoders.1.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.weight. decoder.decoders.1.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.norm3.bias. decoder.decoders.1.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.weight. decoder.decoders.1.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear1.bias. decoder.decoders.1.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.weight. decoder.decoders.1.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.1.concat_linear2.bias. decoder.decoders.1.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.weight. decoder.decoders.2.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_q.bias. decoder.decoders.2.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.weight. decoder.decoders.2.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_k.bias. decoder.decoders.2.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.weight. decoder.decoders.2.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_v.bias. decoder.decoders.2.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.weight. decoder.decoders.2.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.self_attn.linear_out.bias. decoder.decoders.2.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.weight. decoder.decoders.2.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_q.bias. decoder.decoders.2.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.weight. decoder.decoders.2.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_k.bias. decoder.decoders.2.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.weight. decoder.decoders.2.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_v.bias. decoder.decoders.2.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.weight. decoder.decoders.2.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.src_attn.linear_out.bias. decoder.decoders.2.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.weight. decoder.decoders.2.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_1.bias. decoder.decoders.2.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.weight. decoder.decoders.2.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.feed_forward.w_2.bias. decoder.decoders.2.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.weight. decoder.decoders.2.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm1.bias. decoder.decoders.2.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.weight. decoder.decoders.2.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm2.bias. decoder.decoders.2.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.weight. decoder.decoders.2.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.norm3.bias. decoder.decoders.2.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.weight. decoder.decoders.2.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear1.bias. decoder.decoders.2.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.weight. decoder.decoders.2.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.2.concat_linear2.bias. decoder.decoders.2.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.weight. decoder.decoders.3.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_q.bias. decoder.decoders.3.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.weight. decoder.decoders.3.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_k.bias. decoder.decoders.3.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.weight. decoder.decoders.3.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_v.bias. decoder.decoders.3.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.weight. decoder.decoders.3.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.self_attn.linear_out.bias. decoder.decoders.3.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.weight. decoder.decoders.3.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_q.bias. decoder.decoders.3.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.weight. decoder.decoders.3.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_k.bias. decoder.decoders.3.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.weight. decoder.decoders.3.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_v.bias. decoder.decoders.3.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.weight. decoder.decoders.3.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.src_attn.linear_out.bias. decoder.decoders.3.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.weight. decoder.decoders.3.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_1.bias. decoder.decoders.3.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.weight. decoder.decoders.3.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.feed_forward.w_2.bias. decoder.decoders.3.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.weight. decoder.decoders.3.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm1.bias. decoder.decoders.3.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.weight. decoder.decoders.3.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm2.bias. decoder.decoders.3.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.weight. decoder.decoders.3.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.norm3.bias. decoder.decoders.3.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.weight. decoder.decoders.3.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear1.bias. decoder.decoders.3.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.weight. decoder.decoders.3.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.3.concat_linear2.bias. decoder.decoders.3.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.weight. decoder.decoders.4.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_q.bias. decoder.decoders.4.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.weight. decoder.decoders.4.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_k.bias. decoder.decoders.4.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.weight. decoder.decoders.4.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_v.bias. decoder.decoders.4.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.weight. decoder.decoders.4.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.self_attn.linear_out.bias. decoder.decoders.4.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.weight. decoder.decoders.4.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_q.bias. decoder.decoders.4.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.weight. decoder.decoders.4.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_k.bias. decoder.decoders.4.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.weight. decoder.decoders.4.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_v.bias. decoder.decoders.4.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.weight. decoder.decoders.4.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.src_attn.linear_out.bias. decoder.decoders.4.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.weight. decoder.decoders.4.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_1.bias. decoder.decoders.4.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.weight. decoder.decoders.4.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.feed_forward.w_2.bias. decoder.decoders.4.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.weight. decoder.decoders.4.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm1.bias. decoder.decoders.4.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.weight. decoder.decoders.4.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm2.bias. decoder.decoders.4.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.weight. decoder.decoders.4.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.norm3.bias. decoder.decoders.4.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.weight. decoder.decoders.4.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear1.bias. decoder.decoders.4.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.weight. decoder.decoders.4.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.4.concat_linear2.bias. decoder.decoders.4.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.weight. decoder.decoders.5.self_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_q.bias. decoder.decoders.5.self_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.weight. decoder.decoders.5.self_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_k.bias. decoder.decoders.5.self_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.weight. decoder.decoders.5.self_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_v.bias. decoder.decoders.5.self_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.weight. decoder.decoders.5.self_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.self_attn.linear_out.bias. decoder.decoders.5.self_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.weight. decoder.decoders.5.src_attn.linear_q.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_q.bias. decoder.decoders.5.src_attn.linear_q.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.weight. decoder.decoders.5.src_attn.linear_k.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_k.bias. decoder.decoders.5.src_attn.linear_k.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.weight. decoder.decoders.5.src_attn.linear_v.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_v.bias. decoder.decoders.5.src_attn.linear_v.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.weight. decoder.decoders.5.src_attn.linear_out.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.src_attn.linear_out.bias. decoder.decoders.5.src_attn.linear_out.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.weight. decoder.decoders.5.feed_forward.w_1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_1.bias. decoder.decoders.5.feed_forward.w_1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.weight. decoder.decoders.5.feed_forward.w_2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.feed_forward.w_2.bias. decoder.decoders.5.feed_forward.w_2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.weight. decoder.decoders.5.norm1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm1.bias. decoder.decoders.5.norm1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.weight. decoder.decoders.5.norm2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm2.bias. decoder.decoders.5.norm2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.weight. decoder.decoders.5.norm3.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.norm3.bias. decoder.decoders.5.norm3.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.weight. decoder.decoders.5.concat_linear1.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear1.bias. decoder.decoders.5.concat_linear1.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.weight. decoder.decoders.5.concat_linear2.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for decoder.decoders.5.concat_linear2.bias. decoder.decoders.5.concat_linear2.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", - "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n", - " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n" - ] - } - ], - "source": [ - "model.set_state_dict(paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "fc7edf1e", - "metadata": {}, - "outputs": [], - "source": [ - "e_state = model.encoder.state_dict()\n", - "for key, value in e_state.items():\n", - " if 'concat_linear' in key:\n", - " continue\n", - " if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n", - " print(key)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "572097d0", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "748250b7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91e5deee", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "fleet-despite", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", - "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" - ] - } - ], - "source": [ - "%ls .notebook/espnet" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "abroad-oracle", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8, 57, 83)\n", - "(8, 1, 57)\n", - "[57 50 48 38 32 31 28 25]\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n", - "xs=data['xs']\n", - "masks=data['masks']\n", - "print(xs.shape)\n", - "print(masks.shape)\n", - "xs_lens = masks.sum(axis=-1).squeeze()\n", - "print(xs_lens)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "false-instrument", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[8, 13, 256]\n", - "[8, 1, 13]\n" - ] - } - ], - "source": [ - "# ecnoder\n", - "xs = paddle.to_tensor(xs, dtype='float32')\n", - "x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n", - "model.eval()\n", - "encoder_out, encoder_mask = model.encoder(xs, x_lens)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "arctic-proxy", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(8, 13, 256)\n", - "(8, 1, 13)\n", - "False\n", - "False\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n", - "xs = data['xs']\n", - "masks = data['masks']\n", - "print(xs.shape)\n", - "print(masks.shape)\n", - "print(np.allclose(xs, encoder_out.numpy()))\n", - "print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n", - "print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n", - "print(np.allclose(masks, encoder_mask.numpy()))" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "seasonal-switch", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n", - " -0.6526459 ]\n", - " [ 2.1926146 2.1373641 -0.6548196 ... -0.897318 0.6044322\n", - " -0.63332295]\n", - " [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n", - " -0.05243584]\n", - " ...\n", - " [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n", - " -0.18188554]\n", - " [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n", - " -0.72916794]\n", - " [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n", - " -0.32597935]]\n", - "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n", - " [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n", - " [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n", - " ...,\n", - " [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n", - " [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n", - " [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n" - ] - } - ], - "source": [ - "print(xs[0])\n", - "print(encoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "defined-brooks", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n", - " -0.16355833]\n", - " [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n", - " 0.11874156]\n", - " [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n", - " 0.27621832]\n", - " ...\n", - " [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n", - " -0.9794884 ]\n", - " [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n", - " -1.3076135 ]\n", - " [ 2.160009 0.98425585 -1.2231126 ... -0.03393313 1.9141548\n", - " -1.0099151 ]]\n", - "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", - " [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n", - " [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n", - " [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n", - " ...,\n", - " [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n", - " [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n", - " [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n" - ] - } - ], - "source": [ - "print(xs[1])\n", - "print(encoder_out[1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0504e3f8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/.notebook/wenet_model.ipynb b/.notebook/wenet_model.ipynb deleted file mode 100644 index 8e10b6c4b..000000000 --- a/.notebook/wenet_model.ipynb +++ /dev/null @@ -1,5015 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cfb832c0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/workspace/wenet\n" - ] - }, - { - "data": { - "text/plain": [ - "'/workspace/wenet'" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%cd /workspace/wenet/\n", - "%pwd" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "62277538", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "import argparse\n", - "import copy\n", - "import logging\n", - "import os\n", - "\n", - "import torch\n", - "import torch.distributed as dist\n", - "import torch.optim as optim\n", - "import yaml\n", - "from tensorboardX import SummaryWriter\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from wenet.dataset.dataset import AudioDataset, CollateFunc\n", - "from wenet.transformer.asr_model import init_asr_model\n", - "from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n", - "from wenet.utils.executor import Executor\n", - "from wenet.utils.scheduler import WarmupLR\n", - "\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "2f6ea33a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n" - ] - } - ], - "source": [ - "parser = argparse.ArgumentParser(description='training your network')\n", - "parser.add_argument('--config', default=\"examples/aishell/s0/conf/train_conformer.yaml\", help='config file')\n", - "parser.add_argument('--train_data', default=\"examples/aishell/s0/raw_wav/train/format.data\", help='train data file')\n", - "parser.add_argument('--cv_data', default=\"examples/aishell/s0/raw_wav/dev/format.data\", help='cv data file')\n", - "parser.add_argument('--gpu',\n", - " type=int,\n", - " default=-1,\n", - " help='gpu id for this local rank, -1 for cpu')\n", - "parser.add_argument('--model_dir' , help='save model dir')\n", - "parser.add_argument('--checkpoint', help='checkpoint model')\n", - "parser.add_argument('--tensorboard_dir',\n", - " default='tensorboard',\n", - " help='tensorboard log dir')\n", - "parser.add_argument('--ddp.rank',\n", - " dest='rank',\n", - " default=0,\n", - " type=int,\n", - " help='global rank for distributed training')\n", - "parser.add_argument('--ddp.world_size',\n", - " dest='world_size',\n", - " default=-1,\n", - " type=int,\n", - " help='''number of total processes/gpus for\n", - " distributed training''')\n", - "parser.add_argument('--ddp.dist_backend',\n", - " dest='dist_backend',\n", - " default='nccl',\n", - " choices=['nccl', 'gloo'],\n", - " help='distributed backend')\n", - "parser.add_argument('--ddp.init_method',\n", - " dest='init_method',\n", - " default=None,\n", - " help='ddp init method')\n", - "parser.add_argument('--num_workers',\n", - " default=0,\n", - " type=int,\n", - " help='num of subprocess workers for reading')\n", - "parser.add_argument('--pin_memory',\n", - " action='store_true',\n", - " default=False,\n", - " help='Use pinned memory buffers used for reading')\n", - "parser.add_argument('--cmvn', default=\"examples/aishell/s0/raw_wav/train/global_cmvn\", help='global cmvn file')\n", - "\n", - "args = parser.parse_args([])\n", - "print(vars(args))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f5d6af9b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)\n" - ] - } - ], - "source": [ - "# Set random seed\n", - "torch.manual_seed(777)\n", - "print(args)\n", - "with open(args.config, 'r') as fin:\n", - " configs = yaml.load(fin, Loader=yaml.FullLoader)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "264bd353", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "7507 batches\n", - "896\n" - ] - } - ], - "source": [ - "raw_wav = configs['raw_wav']\n", - "\n", - "train_collate_func = CollateFunc(**configs['collate_conf'],\n", - " raw_wav=raw_wav)\n", - "\n", - "cv_collate_conf = copy.deepcopy(configs['collate_conf'])\n", - "# no augmenation on cv set\n", - "cv_collate_conf['spec_aug'] = False\n", - "cv_collate_conf['spec_sub'] = False\n", - "if raw_wav:\n", - " cv_collate_conf['feature_dither'] = 0.0\n", - " cv_collate_conf['speed_perturb'] = False\n", - " cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0\n", - "cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)\n", - "\n", - "dataset_conf = configs.get('dataset_conf', {})\n", - "train_dataset = AudioDataset(args.train_data,\n", - " **dataset_conf,\n", - " raw_wav=raw_wav)\n", - "cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)\n", - "# 120098 data/train/wav.scp\n", - "print(len(train_dataset), 'batches')\n", - "# 14326 data/dev/wav.scp\n", - "print(len(cv_dataset))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "88863d3c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "896\n" - ] - } - ], - "source": [ - "train_sampler = None\n", - "cv_sampler = None\n", - "train_data_loader = DataLoader(train_dataset,\n", - " collate_fn=train_collate_func,\n", - " sampler=train_sampler,\n", - " #shuffle=(train_sampler is None),\n", - " shuffle=False,\n", - " pin_memory=args.pin_memory,\n", - " batch_size=1,\n", - " num_workers=args.num_workers)\n", - "cv_data_loader = DataLoader(cv_dataset,\n", - " collate_fn=cv_collate_func,\n", - " sampler=cv_sampler,\n", - " shuffle=False,\n", - " batch_size=1,\n", - " pin_memory=args.pin_memory,\n", - " num_workers=args.num_workers)\n", - "print(len(cv_data_loader))" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "10d5acd4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4233 vocab\n", - "80 feat dim\n" - ] - } - ], - "source": [ - "if raw_wav:\n", - " input_dim = configs['collate_conf']['feature_extraction_conf'][\n", - " 'mel_bins']\n", - "else:\n", - " input_dim = train_dataset.input_dim\n", - "vocab_size = train_dataset.output_dim\n", - "print(vocab_size, 'vocab')\n", - "print(input_dim , 'feat dim')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "0380ef5a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "examples/aishell/s0/raw_wav/train/global_cmvn\n" - ] - } - ], - "source": [ - "# Save configs to model_dir/train.yaml for inference and export\n", - "configs['input_dim'] = input_dim\n", - "configs['output_dim'] = vocab_size\n", - "configs['cmvn_file'] = args.cmvn\n", - "configs['is_json_cmvn'] = raw_wav\n", - "print(args.cmvn)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "15ebf2bf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(80,)\n", - "(80,)\n", - "[ 9.87176362 9.93891555 10.23818678 10.85971412 11.68652649 12.2548801\n", - " 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205\n", - " 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338\n", - " 12.09285066 11.79395822 11.62259065 11.9263303 11.8154422 11.95122567\n", - " 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142\n", - " 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304\n", - " 12.45014401 12.4966879 12.48653775 12.3550783 12.39291732 12.2553737\n", - " 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342\n", - " 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142\n", - " 13.0588804 13.05737948 12.99921175 12.93402238 12.87429219 12.71652995\n", - " 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073\n", - " 12.51480191 12.5785164 12.64719411 12.73762568 12.80017069 12.86872766\n", - " 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279\n", - " 12.87936075 11.18310185]\n", - "[0.61219383 0.49700994 0.33439025 0.31503119 0.29640823 0.28411759\n", - " 0.26972922 0.25610475 0.24632936 0.24610228 0.24733299 0.24426536\n", - " 0.23751781 0.22987273 0.22659963 0.2268427 0.23059031 0.23420722\n", - " 0.23771761 0.2411352 0.24404673 0.24557175 0.24724932 0.25055198\n", - " 0.25482755 0.2602407 0.26363878 0.26503898 0.2648467 0.26435072\n", - " 0.26353625 0.26364794 0.26411054 0.26339948 0.26212082 0.26146597\n", - " 0.26196556 0.26365859 0.26592959 0.26963884 0.27392766 0.27818809\n", - " 0.28313664 0.2863325 0.28713431 0.28649323 0.28636648 0.2867843\n", - " 0.28635904 0.28562022 0.28492711 0.28429201 0.28402977 0.28401045\n", - " 0.28560797 0.28728033 0.28969549 0.29351627 0.29826453 0.30572631\n", - " 0.31811682 0.32887739 0.33288219 0.33326245 0.33014147 0.32403202\n", - " 0.31903576 0.31316258 0.30741037 0.30370692 0.30204833 0.30049064\n", - " 0.29901079 0.29824511 0.29812308 0.29753329 0.29779342 0.30175296\n", - " 0.30955538 0.32904205]\n" - ] - } - ], - "source": [ - "import json\n", - "import math\n", - "import numpy as np\n", - "def _load_json_cmvn(json_cmvn_file):\n", - " \"\"\" Load the json format cmvn stats file and calculate cmvn\n", - "\n", - " Args:\n", - " json_cmvn_file: cmvn stats file in json format\n", - "\n", - " Returns:\n", - " a numpy array of [means, vars]\n", - " \"\"\"\n", - " with open(json_cmvn_file) as f:\n", - " cmvn_stats = json.load(f)\n", - "\n", - " means = cmvn_stats['mean_stat']\n", - " variance = cmvn_stats['var_stat']\n", - " count = cmvn_stats['frame_num']\n", - " for i in range(len(means)):\n", - " means[i] /= count\n", - " variance[i] = variance[i] / count - means[i] * means[i]\n", - " if variance[i] < 1.0e-20:\n", - " variance[i] = 1.0e-20\n", - " variance[i] = 1.0 / math.sqrt(variance[i])\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def _load_kaldi_cmvn(kaldi_cmvn_file):\n", - " \"\"\" Load the kaldi format cmvn stats file and calculate cmvn\n", - "\n", - " Args:\n", - " kaldi_cmvn_file: kaldi text style global cmvn file, which\n", - " is generated by:\n", - " compute-cmvn-stats --binary=false scp:feats.scp global_cmvn\n", - "\n", - " Returns:\n", - " a numpy array of [means, vars]\n", - " \"\"\"\n", - " means = []\n", - " variance = []\n", - " with open(kaldi_cmvn_file, 'r') as fid:\n", - " # kaldi binary file start with '\\0B'\n", - " if fid.read(2) == '\\0B':\n", - " logger.error('kaldi cmvn binary file is not supported, please '\n", - " 'recompute it by: compute-cmvn-stats --binary=false '\n", - " ' scp:feats.scp global_cmvn')\n", - " sys.exit(1)\n", - " fid.seek(0)\n", - " arr = fid.read().split()\n", - " assert (arr[0] == '[')\n", - " assert (arr[-2] == '0')\n", - " assert (arr[-1] == ']')\n", - " feat_dim = int((len(arr) - 2 - 2) / 2)\n", - " for i in range(1, feat_dim + 1):\n", - " means.append(float(arr[i]))\n", - " count = float(arr[feat_dim + 1])\n", - " for i in range(feat_dim + 2, 2 * feat_dim + 2):\n", - " variance.append(float(arr[i]))\n", - "\n", - " for i in range(len(means)):\n", - " means[i] /= count\n", - " variance[i] = variance[i] / count - means[i] * means[i]\n", - " if variance[i] < 1.0e-20:\n", - " variance[i] = 1.0e-20\n", - " variance[i] = 1.0 / math.sqrt(variance[i])\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):\n", - " npzfile = np.load(npz_cmvn_file)\n", - " means = npzfile[\"mean\"] #(1, D)\n", - " std = npzfile[\"std\"] #(1, D)\n", - " std = np.clip(std, eps, None)\n", - " variance = 1.0 / std\n", - " cmvn = np.array([means, variance])\n", - " return cmvn\n", - "\n", - "\n", - "def load_cmvn(cmvn_file: str, filetype: str):\n", - " \"\"\"load cmvn from file.\n", - "\n", - " Args:\n", - " cmvn_file (str): cmvn path.\n", - " filetype (str): file type, optional[npz, json, kaldi].\n", - "\n", - " Raises:\n", - " ValueError: file type not support.\n", - "\n", - " Returns:\n", - " Tuple[np.ndarray, np.ndarray]: mean, istd\n", - " \"\"\"\n", - " assert filetype in ['npz', 'json', 'kaldi'], filetype\n", - " filetype = filetype.lower()\n", - " if filetype == \"json\":\n", - " cmvn = _load_json_cmvn(cmvn_file)\n", - " elif filetype == \"kaldi\":\n", - " cmvn = _load_kaldi_cmvn(cmvn_file)\n", - " elif filetype == \"npz\":\n", - " cmvn = _load_npz_cmvn(cmvn_file)\n", - " else:\n", - " raise ValueError(f\"cmvn file type no support: {filetype}\")\n", - " return cmvn[0], cmvn[1]\n", - "\n", - "mean, istd = load_cmvn(args.cmvn, 'json')\n", - "print(mean.shape)\n", - "print(istd.shape)\n", - "print(mean)\n", - "print(istd)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3cfa5e23", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ASRModel(\n", - " (encoder): ConformerEncoder(\n", - " (global_cmvn): GlobalCMVN()\n", - " (embed): Conv2dSubsampling4(\n", - " (conv): Sequential(\n", - " (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " (1): ReLU()\n", - " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", - " (3): ReLU()\n", - " )\n", - " (out): Sequential(\n", - " (0): Linear(in_features=4864, out_features=256, bias=True)\n", - " )\n", - " (pos_enc): RelPositionalEncoding(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (encoders): ModuleList(\n", - " (0): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (1): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (2): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (3): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (4): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (5): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (6): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (7): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (8): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (9): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (10): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (11): ConformerEncoderLayer(\n", - " (self_attn): RelPositionMultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (feed_forward_macaron): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): Swish()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (conv_module): ConvolutionModule(\n", - " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", - " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", - " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", - " (activation): Swish()\n", - " )\n", - " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " )\n", - " )\n", - " (decoder): TransformerDecoder(\n", - " (embed): Sequential(\n", - " (0): Embedding(4233, 256)\n", - " (1): PositionalEncoding(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n", - " (decoders): ModuleList(\n", - " (0): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (src_attn): MultiHeadedAttention(\n", - " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", - " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (feed_forward): PositionwiseFeedForward(\n", - " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", - " (activation): ReLU()\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", - " )\n", - " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", - " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", - " )\n", - " )\n", - " )\n", - " (ctc): CTC(\n", - " (ctc_lo): Linear(in_features=256, out_features=4233, bias=True)\n", - " (ctc_loss): CTCLoss()\n", - " )\n", - " (criterion_att): LabelSmoothingLoss(\n", - " (criterion): KLDivLoss()\n", - " )\n", - ")\n" - ] - } - ], - "source": [ - "# Init asr model from configs\n", - "model = init_asr_model(configs)\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "3c780af5", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def summary(layer, print_func=print):\n", - " num_params = num_elements = 0\n", - " for name, param in layer.state_dict().items():\n", - " if print_func:\n", - " print_func(\n", - " \"{} | {} | {}\".format(name, param.shape, np.prod(param.shape)))\n", - " num_elements += np.prod(param.shape)\n", - " num_params += 1\n", - " if print_func:\n", - " print_func(\n", - " f\"Total parameters: {num_params}, {num_elements} elements.\"\n", - " )\n", - " \n", - "def print_params(model, print_func=print):\n", - " if print_func is None:\n", - " return\n", - " total = 0.0\n", - " num_params = 0.0\n", - " for n, p in model.named_parameters():\n", - " msg = f\"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}\"\n", - " total += np.prod(p.shape)\n", - " num_params += 1\n", - " if print_func:\n", - " print_func(msg)\n", - " if print_func:\n", - " print_func(f\"Total parameters: {num_params}, {total} elements.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "e159a200", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.global_cmvn.mean | torch.Size([80]) | 80\n", - "encoder.global_cmvn.istd | torch.Size([80]) | 80\n", - "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304\n", - "encoder.embed.conv.0.bias | torch.Size([256]) | 256\n", - "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824\n", - "encoder.embed.conv.2.bias | torch.Size([256]) | 256\n", - "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n", - "encoder.embed.out.0.bias | torch.Size([256]) | 256\n", - "encoder.after_norm.weight | torch.Size([256]) | 256\n", - "encoder.after_norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", - "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", - "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", - "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n", - "encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n", - "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n", - "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072\n", - "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n", - "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n", - "decoder.after_norm.weight | torch.Size([256]) | 256\n", - "decoder.after_norm.bias | torch.Size([256]) | 256\n", - "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n", - "decoder.output_layer.bias | torch.Size([4233]) | 4233\n", - "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", - "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", - "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", - "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", - "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n", - "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n", - "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", - "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256\n", - "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n", - "ctc.ctc_lo.bias | torch.Size([4233]) | 4233\n", - "Total parameters: 701, 49355454.0 elements.\n" - ] - } - ], - "source": [ - "summary(model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8494c6ab", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0648a969", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n", - "encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n", - "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n", - "encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n", - "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n", - "encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n", - "encoder.after_norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.after_norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", - "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", - "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", - "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n", - "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", - "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n", - "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "decoder.after_norm.weight | torch.Size([256]) | 256 | True\n", - "decoder.after_norm.bias | torch.Size([256]) | 256 | True\n", - "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n", - "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", - "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", - "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", - "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256 | True\n", - "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", - "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256 | True\n", - "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648 | True\n", - "ctc.ctc_lo.bias | torch.Size([4233]) | 4233 | True\n", - "Total parameters: 663.0, 49349138.0 elements.\n" - ] - } - ], - "source": [ - "print_params(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "5ad6de2a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", - "torch.Size([16, 207, 80])\n", - "tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", - " [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", - " [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", - " ...,\n", - " [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", - " [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", - " [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", - "\n", - " [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", - " [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", - " [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", - " ...,\n", - " [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", - " [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", - " [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", - "\n", - " [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", - " [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", - " [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", - " ...,\n", - " [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " ...,\n", - "\n", - " [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", - " [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", - " [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", - " [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", - " [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", - "\n", - " [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", - " [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", - " [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", - " ...,\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", - " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", - "tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", - " 166, 163], dtype=torch.int32)\n", - "tensor([[2995, 3116, 1209, 565, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077],\n", - " [2693, 524, 234, 1145, 366, -1],\n", - " [3875, 4211, 3062, 700, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710],\n", - " [ 25, 1149, 3930, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110],\n", - " [3703, 2, 565, 3827, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, -1],\n", - " [ 426, 811, 95, 489, 144, -1],\n", - " [2313, 2006, 489, 975, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347],\n", - " [ 70, 1741, 702, 1666, -1, -1],\n", - " [ 703, 1778, 1030, 849, -1, -1],\n", - " [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", - "tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n" - ] - } - ], - "source": [ - "for batch in cv_data_loader:\n", - " keys, feat, text, feat_len, text_len = batch\n", - " print(keys)\n", - " print(feat.shape)\n", - " print(feat)\n", - " print(feat_len)\n", - " print(text)\n", - " print(text_len)\n", - " np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "852a9c95", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CODE_OF_CONDUCT.md data.npz install.sh README.md\t tools\r\n", - "CONTRIBUTING.md docs LICENSE\t requirements.txt venv\r\n", - "CPPLINT.cfg\t examples Makefile\t runtime\t wenet\r\n" - ] - } - ], - "source": [ - "!ls\n", - "!cp data.npz /workspace/DeepSpeech-2.x/.notebook" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "cde24c4e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(111.9988)\n", - "tensor(830.9634, grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True])\n", - "tensor(669.4633, grad_fn=)\n", - "tensor(142.4888, grad_fn=) tensor(41.8415, grad_fn=) tensor(377.3326, grad_fn=)\n" - ] - } - ], - "source": [ - "model.cpu().eval()\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss, attention_loss, ctc_loss )" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "be5b2a2c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], - "source": [ - "print(total_loss.device)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "5b791771", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "cuda:0\n", - "142.4888 41.84146 377.33258\n" - ] - } - ], - "source": [ - "model.cuda().eval()\n", - "feat=feat.cuda()\n", - "feat_len=feat_len.cuda()\n", - "text=text.cuda()\n", - "text_len=text_len.cuda()\n", - "\n", - "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", - " text, text_len)\n", - "print(total_loss.device)\n", - "print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "1baef537", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 51, 256])\n", - "torch.Size([16, 1, 51])\n", - "tensor([[-0.7019, 0.5625, 0.6880, ..., 1.1237, 0.7804, 1.1369],\n", - " [-0.7788, 0.3913, 0.7189, ..., 1.2519, 0.8862, 1.3173],\n", - " [-0.9591, 0.6346, 0.8767, ..., 0.9818, 0.7440, 1.2903],\n", - " ...,\n", - " [-1.0732, 0.6724, 0.9230, ..., 0.9075, 0.8177, 1.3240],\n", - " [-1.1654, 0.6820, 0.6939, ..., 1.2238, 0.8028, 1.4507],\n", - " [-1.2732, 0.7146, 0.7582, ..., 0.9415, 0.8775, 1.2623]],\n", - " device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", - "print(encoder_out.shape)\n", - "print(encoder_mask.shape)\n", - "print(encoder_out[0])\n", - "\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n", - " mask=encoder_mask.cpu().detach().numpy(), \n", - " out=encoder_out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e22c782", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "30b6b946", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 9.871763 9.938915 10.238187 10.8597145 11.686526 12.25488\n", - " 12.657681 12.86139 12.807339 12.566256 12.32007 12.138792\n", - " 12.313189 12.552552 12.612239 12.569745 12.389728 12.143833\n", - " 12.092851 11.793959 11.622591 11.926331 11.815442 11.951225\n", - " 11.831805 11.887888 11.790144 11.88072 11.900057 11.973481\n", - " 12.009822 12.008814 12.026197 12.104796 12.21555 12.343993\n", - " 12.450144 12.496688 12.486538 12.355079 12.392918 12.255374\n", - " 12.264963 12.253142 12.325458 12.4335985 12.548675 12.676334\n", - " 12.809207 12.929347 12.961151 12.968834 12.995931 13.047281\n", - " 13.058881 13.05738 12.999211 12.934022 12.874292 12.71653\n", - " 12.48942 12.274784 12.261631 12.286319 12.31956 12.422907\n", - " 12.514802 12.578516 12.647194 12.737626 12.800171 12.868728\n", - " 12.966668 13.064786 13.159159 13.272843 13.310819 13.239043\n", - " 12.879361 11.183102 ] float32\n", - "encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n", - "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n", - "encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n", - "encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n", - "encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n", - "encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n", - "encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n", - "encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n", - "encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n", - "encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n", - "encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n", - "encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n", - "encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n", - "encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n", - "encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n", - "encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n", - "encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n", - "encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n", - "encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n", - "encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n", - "encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n", - "encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n", - "encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n", - "encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n", - "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", - "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", - "encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", - "encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n", - "encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n", - "encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n", - "decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n", - "decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.1.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.4.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", - "decoder.decoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", - "decoder.decoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", - "decoder.decoders.5.concat_linear1.weight: (256, 512) -> (512, 256)\n", - "decoder.decoders.5.concat_linear2.weight: (256, 512) -> (512, 256)\n", - "ctc.ctc_lo.weight: (4233, 256) -> (256, 4233)\n" - ] - } - ], - "source": [ - "# dump torch model to paddle\n", - "import numpy as np\n", - "state_dict = model.state_dict()\n", - "paddle_state_dict = {}\n", - "\n", - "for n, p in state_dict.items():\n", - " name_change=True\n", - "\n", - " if 'norm.running_mean' in n:\n", - " new_n = n.replace('norm.running_', 'norm._')\n", - " elif 'norm.running_var' in n:\n", - " new_n = n.replace('norm.running_var', 'norm._variance')\n", - " else:\n", - " name_change=False\n", - " new_n = n\n", - " \n", - " if name_change:\n", - " print(f\"{n} -> {new_n}\")\n", - " \n", - " p = p.cpu().detach().numpy()\n", - " if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:\n", - " new_p = p.T\n", - " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", - " else:\n", - " new_p = p\n", - " \n", - " if 'global_cmvn.mean' in n:\n", - " print(p, p.dtype)\n", - " \n", - " paddle_state_dict[new_n] = new_p\n", - " \n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", - " state=paddle_state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7307dc5b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "d99b29bc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(377.3326, device='cuda:0', grad_fn=)\n", - "None\n", - "[[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", - " -5.93366381e-03 -7.26613170e-03]\n", - " [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", - " 2.46338220e-03 2.31891591e-03]\n", - " [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", - " 5.76929189e-03 7.48792710e-03]\n", - " ...\n", - " [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", - " 1.16123557e-02 1.44716976e-02]\n", - " [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", - " 8.58021621e-03 1.07796099e-02]\n", - " [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", - " 1.60815325e-02 2.03892551e-02]]\n", - "[-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", - " 1.0920014e-02 1.3787906e-02]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.\n", - " print(loss_ctc.grad)\n" - ] - } - ], - "source": [ - "encoder_out_lens = encoder_mask.squeeze(1).sum(1)\n", - "loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)\n", - "print(loss_ctc)\n", - "dir(loss_ctc)\n", - "loss_ctc.backward()\n", - "print(loss_ctc.grad)\n", - "#print(model.ctc.ctc_lo.weight.grad)\n", - "print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())\n", - "print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "49b05d6d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=) 0.0\n" - ] - } - ], - "source": [ - "loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,\n", - " text, text_len)\n", - "print(loss_att, acc_att)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "413b413f", - "metadata": {}, - "outputs": [], - "source": [ - "def pad_list(xs, pad_value: int):\n", - " n_batch = len(xs)\n", - " max_len = max([x.size(0) for x in xs])\n", - " pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)\n", - " pad = pad.fill_(pad_value)\n", - " for i in range(n_batch):\n", - " pad[i, :xs[i].size(0)] = xs[i]\n", - "\n", - " return pad\n", - "\n", - "def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,\n", - " ignore_id: int):\n", - "\n", - " _sos = torch.tensor([sos],\n", - " dtype=torch.long,\n", - " requires_grad=False,\n", - " device=ys_pad.device)\n", - " _eos = torch.tensor([eos],\n", - " dtype=torch.long,\n", - " requires_grad=False,\n", - " device=ys_pad.device)\n", - " ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", - " ys_in = [torch.cat([_sos, y], dim=0) for y in ys]\n", - " ys_out = [torch.cat([y, _eos], dim=0) for y in ys]\n", - " return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "ff0c2400", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[4232, 2995, 3116, 1209, 565, 4232, 4232],\n", - " [4232, 236, 1176, 331, 66, 3925, 4077],\n", - " [4232, 2693, 524, 234, 1145, 366, 4232],\n", - " [4232, 3875, 4211, 3062, 700, 4232, 4232],\n", - " [4232, 272, 987, 1134, 494, 2959, 4232],\n", - " [4232, 1936, 3715, 120, 2553, 2695, 2710],\n", - " [4232, 25, 1149, 3930, 4232, 4232, 4232],\n", - " [4232, 1753, 1778, 1237, 482, 3925, 110],\n", - " [4232, 3703, 2, 565, 3827, 4232, 4232],\n", - " [4232, 1150, 2734, 10, 2478, 3490, 4232],\n", - " [4232, 426, 811, 95, 489, 144, 4232],\n", - " [4232, 2313, 2006, 489, 975, 4232, 4232],\n", - " [4232, 3702, 3414, 205, 1488, 2966, 1347],\n", - " [4232, 70, 1741, 702, 1666, 4232, 4232],\n", - " [4232, 703, 1778, 1030, 849, 4232, 4232],\n", - " [4232, 814, 1674, 115, 3827, 4232, 4232]], device='cuda:0')\n", - "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1],\n", - " [3875, 4211, 3062, 700, 4232, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2, 565, 3827, 4232, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", - " [ 426, 811, 95, 489, 144, 4232, -1],\n", - " [2313, 2006, 489, 975, 4232, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", - " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", - " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n" - ] - } - ], - "source": [ - "ys_pad = text\n", - "ys_pad_lens = text_len\n", - "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", - " model.ignore_id)\n", - "ys_in_lens = ys_pad_lens + 1\n", - "print(ys_in_pad)\n", - "print(ys_out_pad)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "3e84da38", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 7, 4233])\n", - "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", - " 1.5035e-02, 4.0337e-01],\n", - " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", - " -1.4353e-01, -1.0024e+00],\n", - " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", - " -1.1672e+00, -2.6849e-01],\n", - " ...,\n", - " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", - " 4.5470e-02, -3.7139e-01],\n", - " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", - " -7.9347e-04, 4.2538e-01],\n", - " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", - " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", - " ys_in_lens)\n", - "print(decoder_out.shape)\n", - "print(decoder_out[0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aac441ea", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "5ddbca73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.float32\n", - "torch.int64\n", - "tensor(112., device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=)\n", - "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", - " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", - " [2693, 524, 234, 1145, 366, 4232, -1],\n", - " [3875, 4211, 3062, 700, 4232, -1, -1],\n", - " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", - " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", - " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", - " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", - " [3703, 2, 565, 3827, 4232, -1, -1],\n", - " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", - " [ 426, 811, 95, 489, 144, 4232, -1],\n", - " [2313, 2006, 489, 975, 4232, -1, -1],\n", - " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", - " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", - " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", - " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n", - "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", - " 1.5035e-02, 4.0337e-01],\n", - " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", - " -1.4353e-01, -1.0024e+00],\n", - " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", - " -1.1672e+00, -2.6849e-01],\n", - " ...,\n", - " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", - " 4.5470e-02, -3.7139e-01],\n", - " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", - " -7.9347e-04, 4.2538e-01],\n", - " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", - " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "print(decoder_out.dtype)\n", - "print(ys_out_pad.dtype)\n", - "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "print(ys_out_pad)\n", - "print(decoder_out[0])\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',\n", - " decoder_out=decoder_out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78f98c0b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "8d968cd3", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch import nn\n", - "\n", - "\n", - "class LabelSmoothingLoss(nn.Module):\n", - " def __init__(self,\n", - " size: int,\n", - " padding_idx: int,\n", - " smoothing: float,\n", - " normalize_length: bool = False):\n", - " \"\"\"Construct an LabelSmoothingLoss object.\"\"\"\n", - " super(LabelSmoothingLoss, self).__init__()\n", - " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", - " self.padding_idx = padding_idx\n", - " self.confidence = 1.0 - smoothing\n", - " self.smoothing = smoothing\n", - " self.size = size\n", - " self.normalize_length = normalize_length\n", - "\n", - " def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"Compute loss between x and target.\n", - "\n", - " The model outputs and data labels tensors are flatten to\n", - " (batch*seqlen, class) shape and a mask is applied to the\n", - " padding part which should not be calculated for loss.\n", - "\n", - " Args:\n", - " x (torch.Tensor): prediction (batch, seqlen, class)\n", - " target (torch.Tensor):\n", - " target signal masked with self.padding_id (batch, seqlen)\n", - " Returns:\n", - " loss (torch.Tensor) : The KL loss, scalar float value\n", - " \"\"\"\n", - " assert x.size(2) == self.size\n", - " batch_size = x.size(0)\n", - " x = x.view(-1, self.size)\n", - " target = target.view(-1)\n", - " # use zeros_like instead of torch.no_grad() for true_dist,\n", - " # since no_grad() can not be exported by JIT\n", - " true_dist = torch.zeros_like(x)\n", - " true_dist.fill_(self.smoothing / (self.size - 1))\n", - " ignore = target == self.padding_idx # (B,)\n", - " print(self.smoothing / (self.size - 1))\n", - " print(true_dist)\n", - " total = len(target) - ignore.sum().item()\n", - " target = target.masked_fill(ignore, 0) # avoid -1 index\n", - " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n", - " print(true_dist.dtype)\n", - " print(true_dist.square().sum())\n", - " kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)\n", - " print(kl.sum())\n", - " denom = total if self.normalize_length else batch_size\n", - " print(ignore)\n", - " numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", - " print(numer)\n", - " return numer /denom" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "3df340ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2.3629489603024576e-05\n", - "tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " ...,\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05],\n", - " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", - " 2.3629e-05]], device='cuda:0')\n", - "torch.float32\n", - "tensor(90.7203, device='cuda:0')\n", - "tensor(830.9634, device='cuda:0', grad_fn=)\n", - "tensor([False, False, False, False, False, True, True, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, True, False, False, False,\n", - " False, False, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, True, True, False, False, False, False, False,\n", - " True, True], device='cuda:0')\n", - "tensor(669.4634, device='cuda:0', grad_fn=)\n", - "tensor(41.8415, device='cuda:0', grad_fn=)\n", - "torch.int64\n" - ] - } - ], - "source": [ - "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", - "loss_att = criteron(decoder_out, ys_out_pad)\n", - "print(loss_att)\n", - "print(ys_out_pad.dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "badc410d", - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecoder_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m allow_unreachable=True) # allow_unreachable flag\n", - "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time." - ] - } - ], - "source": [ - "loss_att.backward()\n", - "print(loss_att.grad)\n", - "print(decoder_out.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "219eb41f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([ 0.0024, 0.0019, -0.1098, ..., 0.0028, 0.0020, -1.7978],\n", - " device='cuda:0')\n", - "tensor([[ 6.5052e-04, 6.4419e-05, -6.1955e-06, ..., 9.8220e-04,\n", - " -2.5918e-05, 3.3754e-04],\n", - " [ 3.9305e-04, 4.5799e-04, 1.4362e-04, ..., 4.6800e-04,\n", - " 1.6911e-04, 2.7067e-04],\n", - " [-1.3593e-01, 5.2201e-02, 3.2895e-02, ..., 2.4580e-02,\n", - " 1.4590e-01, -4.6850e-02],\n", - " ...,\n", - " [ 1.0434e-03, 4.2251e-04, 6.5688e-04, ..., 1.2144e-03,\n", - " 2.1159e-04, 6.6838e-04],\n", - " [ 6.4997e-04, 4.4301e-04, 4.1550e-04, ..., 1.0420e-03,\n", - " 2.4114e-04, 1.5338e-04],\n", - " [-9.9337e-01, 5.4573e-01, -1.1371e-02, ..., -4.3175e-01,\n", - " -2.7850e-01, -4.4679e-01]], device='cuda:0')\n" - ] - } - ], - "source": [ - "print(model.decoder.output_layer.bias.grad)\n", - "print(model.decoder.output_layer.weight.grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "40d00a54", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01, ..., -8.2428e-01,\n", - " -1.0265e+00, -9.6301e-01],\n", - " [-4.4642e-02, 2.3176e-01, -3.2539e-01, ..., -9.0159e-01,\n", - " -1.0325e+00, -7.5987e-01],\n", - " [ 5.0035e-01, 2.2691e-01, -7.3052e-01, ..., -1.0055e+00,\n", - " -8.7123e-01, -1.0306e+00],\n", - " ...,\n", - " [-4.0024e-01, -1.4325e-01, -5.7947e-01, ..., -1.0718e+00,\n", - " -1.2806e+00, -1.0518e+00],\n", - " [ 1.5755e-01, -1.8495e-03, -2.8703e-01, ..., -1.1090e+00,\n", - " -9.4519e-01, -7.2506e-01],\n", - " [-4.7520e-01, -1.3942e+00, -2.5754e-01, ..., -1.1365e+00,\n", - " -1.1943e+00, -1.2290e+00]],\n", - "\n", - " [[ 9.5454e-01, 3.6428e-01, -1.3891e+00, ..., -1.1637e+00,\n", - " -1.2845e+00, -1.2015e+00],\n", - " [-8.5735e-02, -1.0579e+00, -8.9173e-01, ..., -9.6441e-01,\n", - " -1.1255e+00, -1.2599e+00],\n", - " [ 4.7654e-01, 3.2887e-01, -5.9201e-01, ..., -1.1942e+00,\n", - " -1.1430e+00, -1.0242e+00],\n", - " ...,\n", - " [-4.7431e-01, -3.3559e-01, -7.2326e-01, ..., -1.4506e+00,\n", - " -1.3957e+00, -1.0464e+00],\n", - " [ 3.6113e-01, 1.0381e-01, -1.1599e+00, ..., -1.0439e+00,\n", - " -1.0221e+00, -1.0208e+00],\n", - " [-1.2717e+00, -2.1460e+00, -7.5677e-01, ..., -9.7822e-01,\n", - " -9.3785e-01, -1.0371e+00]],\n", - "\n", - " [[-1.5465e+00, -1.0152e+00, -8.8901e-01, ..., -4.8522e-01,\n", - " -7.5163e-01, -6.7765e-01],\n", - " [-7.6101e-01, -7.3352e-01, -9.1588e-01, ..., -2.4836e-01,\n", - " -5.8927e-01, -7.3723e-01],\n", - " [-2.4714e-02, 1.7016e-01, -4.2326e-01, ..., -3.3204e-01,\n", - " -7.6696e-01, -7.1652e-01],\n", - " ...,\n", - " [-1.7032e+00, -1.2591e+00, -1.1449e+00, ..., -1.1810e+00,\n", - " -1.1163e+00, -9.3108e-01],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[ 6.4983e-01, 2.6117e-01, -8.4197e-01, ..., -8.7213e-01,\n", - " -1.1073e+00, -1.3253e+00],\n", - " [ 3.5391e-01, -1.5846e-02, -4.0425e-01, ..., -9.9173e-01,\n", - " -1.0727e+00, -1.1924e+00],\n", - " [ 3.7704e-01, -6.2785e-02, -1.1468e-01, ..., -1.1021e+00,\n", - " -1.0952e+00, -1.1182e+00],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " [[ 4.4458e-02, -1.7547e-01, -6.7475e-01, ..., -4.9801e-01,\n", - " -5.6783e-01, -7.7852e-01],\n", - " [-1.3428e+00, -8.0343e-01, -9.0457e-01, ..., -6.5902e-01,\n", - " -7.2550e-01, -6.2796e-01],\n", - " [-7.6253e-01, -1.3071e-01, -1.3280e-01, ..., -5.6133e-01,\n", - " -6.0588e-01, -7.2115e-01],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]],\n", - "\n", - " [[-1.0798e+00, -1.0834e+00, -1.1797e+00, ..., -1.7757e-01,\n", - " -4.3747e-01, -4.0007e-02],\n", - " [ 9.2354e-01, 6.3771e-01, -5.2810e-01, ..., -1.2928e-01,\n", - " -2.0342e-01, 1.6656e-01],\n", - " [ 4.9337e-01, -9.1133e-03, -7.3302e-01, ..., 1.0074e-01,\n", - " -9.8115e-02, -9.2357e-03],\n", - " ...,\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00],\n", - " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", - " -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "print(xs)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "505ca294", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[ True, True, True, ..., True, True, True]],\n", - "\n", - " [[ True, True, True, ..., True, True, True]],\n", - "\n", - " [[ True, True, True, ..., True, False, False]],\n", - "\n", - " ...,\n", - "\n", - " [[ True, True, True, ..., False, False, False]],\n", - "\n", - " [[ True, True, True, ..., False, False, False]],\n", - "\n", - " [[ True, True, True, ..., False, False, False]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "from wenet.utils.mask import make_pad_mask\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "print(masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "aa03c2b9", - "metadata": {}, - "outputs": [], - "source": [ - "xs, pos_emb, masks = model.encoder.embed(xs, masks)" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "ebc0ea12", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-0.5482, 2.2866, -1.0750, ..., 1.4504, 0.2895, -0.6945],\n", - " [-0.8013, 1.7688, -1.6639, ..., 1.8332, 0.6791, -0.2000],\n", - " [-1.7112, 2.7057, -1.3363, ..., 1.2336, 0.1870, -0.5735],\n", - " ...,\n", - " [-0.9697, 2.3129, -0.8752, ..., 0.8584, 0.4853, -0.4177],\n", - " [-1.3609, 2.1779, -1.7813, ..., 2.0928, 0.2528, -0.3650],\n", - " [-1.6967, 2.3544, -1.7417, ..., 1.3670, 0.5951, -0.7415]],\n", - "\n", - " [[-1.9828, 2.3178, -0.9079, ..., 0.4117, 0.5006, 0.0872],\n", - " [-0.7640, 1.3558, -1.3613, ..., 0.7317, 0.6784, 0.1685],\n", - " [-0.9504, 1.6038, -1.3030, ..., 0.5754, 0.2677, 0.3343],\n", - " ...,\n", - " [-1.4757, 2.5317, -1.2321, ..., 1.2997, 0.5019, -0.1034],\n", - " [-1.1731, 2.3172, -1.2542, ..., 1.7391, 0.2171, -0.4445],\n", - " [-1.2700, 3.2229, -0.8872, ..., 1.6461, 0.0973, -0.7679]],\n", - "\n", - " [[-0.5873, 1.4291, -1.3950, ..., 0.2102, 0.1027, 0.0918],\n", - " [ 0.1743, 1.7834, -1.6422, ..., 0.8113, 0.3137, 0.5634],\n", - " [-0.3492, 1.8310, -1.0685, ..., 0.6924, 0.1378, 0.4594],\n", - " ...,\n", - " [-1.0869, 2.3002, -1.2638, ..., 1.7998, 0.5134, -0.5223],\n", - " [-1.2614, 2.7240, -1.3734, ..., 1.4445, 0.5742, -0.3320],\n", - " [-2.2068, 4.3462, -3.8289, ..., 2.1426, 1.2034, -1.3795]],\n", - "\n", - " ...,\n", - "\n", - " [[-0.3914, 1.8553, -0.5747, ..., 1.0062, 0.4632, -1.0452],\n", - " [-0.8605, 2.0172, -1.4437, ..., 1.4526, 0.1657, 0.5923],\n", - " [-0.7307, 2.2841, -1.0699, ..., 1.5825, -0.0980, 0.5503],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", - "\n", - " [[-0.1619, 0.6255, -1.1323, ..., 0.0724, -0.2204, 0.4636],\n", - " [-0.0831, 0.5750, -1.0930, ..., 0.9110, -0.0650, 0.7299],\n", - " [-0.2820, 0.0801, -0.9418, ..., 0.3379, -0.1166, 0.4451],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", - "\n", - " [[-0.5458, -0.6909, -1.3597, ..., -0.7818, 0.6875, 0.9843],\n", - " [ 0.0421, -1.1062, -1.4389, ..., -0.0239, 0.9115, 0.5287],\n", - " [-0.2909, -0.1886, -1.5487, ..., -0.1392, 0.0580, 0.3066],\n", - " ...,\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", - " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]]],\n", - " device='cuda:0', grad_fn=)\n", - "tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", - " 0.0000e+00, 1.0000e+00],\n", - " [ 8.4147e-01, 5.4030e-01, 8.0196e-01, ..., 1.0000e+00,\n", - " 1.0746e-04, 1.0000e+00],\n", - " [ 9.0930e-01, -4.1615e-01, 9.5814e-01, ..., 1.0000e+00,\n", - " 2.1492e-04, 1.0000e+00],\n", - " ...,\n", - " [-7.6825e-01, -6.4014e-01, 6.3280e-01, ..., 9.9998e-01,\n", - " 5.1581e-03, 9.9999e-01],\n", - " [-9.5375e-01, 3.0059e-01, 9.9899e-01, ..., 9.9998e-01,\n", - " 5.2656e-03, 9.9999e-01],\n", - " [-2.6237e-01, 9.6497e-01, 5.6075e-01, ..., 9.9998e-01,\n", - " 5.3730e-03, 9.9999e-01]]], device='cuda:0')\n", - "tensor([[[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, False, False, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, False, False, False, False, False, False, False, False,\n", - " False]],\n", - "\n", - " [[ True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, True, True, True, True, True, True, True, True, True,\n", - " True, False, False, False, False, False, False, False, False, False,\n", - " False]]], device='cuda:0')\n", - "torch.Size([16, 1, 51])\n" - ] - } - ], - "source": [ - "print(xs)\n", - "print(pos_emb)\n", - "print(masks)\n", - "print(masks.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "id": "4289461b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", - " -0.6945408 ]\n", - " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", - " -0.1999542 ]\n", - " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", - " -0.5735198 ]\n", - " ...\n", - " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", - " -0.41773027]\n", - " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", - " -0.36496443]\n", - " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", - " -0.74147725]]\n", - "\n", - " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", - " 0.08721463]\n", - " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", - " 0.16851945]\n", - " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", - " 0.33433008]\n", - " ...\n", - " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", - " -0.10343577]\n", - " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", - " -0.44447583]\n", - " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", - " -0.7678688 ]]\n", - "\n", - " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", - " 0.09179455]\n", - " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", - " 0.56344515]\n", - " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", - " 0.45937473]\n", - " ...\n", - " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", - " -0.52227837]\n", - " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", - " -0.33201432]\n", - " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", - " -1.3795122 ]]\n", - "\n", - " ...\n", - "\n", - " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", - " -1.045236 ]\n", - " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", - " 0.5923172 ]\n", - " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", - " 0.55030036]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", - " 0.46362036]\n", - " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", - " 0.72986233]\n", - " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", - " 0.44514441]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", - " 0.9842716 ]\n", - " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", - " 0.52870303]\n", - " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", - " 0.30663735]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]]\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", - "print(xs.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "67e10d73", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 2.0908e-03],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 1.1943e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 4.6105e-02, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 9.6723e-03,\n", - " 4.6135e-02, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[2.2816e-01, 2.4615e-01, 2.5304e-01, ..., 2.0402e-01,\n", - " 2.3248e-01, 3.1191e-01],\n", - " [1.3587e-01, 2.8877e-01, 2.7991e-01, ..., 1.9210e-01,\n", - " 2.0346e-01, 1.9934e-01],\n", - " [2.5739e-01, 3.9348e-01, 2.7877e-01, ..., 2.7483e-01,\n", - " 1.9302e-01, 2.3810e-01],\n", - " ...,\n", - " [1.1939e-01, 2.8473e-01, 3.3082e-01, ..., 2.3838e-01,\n", - " 2.2104e-01, 2.3906e-01],\n", - " [1.7388e-01, 2.0402e-01, 4.0263e-01, ..., 2.4782e-01,\n", - " 2.6742e-01, 1.5427e-01],\n", - " [0.0000e+00, 2.9081e-01, 2.7726e-01, ..., 1.7540e-01,\n", - " 1.8479e-01, 2.2483e-01]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.5447e-01, 3.8861e-01, 3.9724e-01, ..., 3.8680e-01,\n", - " 3.3568e-01, 3.4552e-01],\n", - " [4.1739e-01, 5.1039e-01, 4.1730e-01, ..., 3.3993e-01,\n", - " 3.7082e-01, 3.5110e-01],\n", - " [3.6117e-01, 4.0745e-01, 4.8491e-01, ..., 3.4849e-01,\n", - " 3.2321e-01, 3.5189e-01],\n", - " ...,\n", - " [2.3144e-01, 3.8021e-01, 5.1526e-01, ..., 3.6499e-01,\n", - " 3.7412e-01, 3.9986e-01],\n", - " [3.4679e-01, 4.0238e-01, 5.0077e-01, ..., 3.6185e-01,\n", - " 3.1597e-01, 3.6335e-01],\n", - " [3.6498e-01, 3.7943e-01, 5.1719e-01, ..., 3.1798e-01,\n", - " 3.3657e-01, 3.4130e-01]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.4560e-02, 9.4475e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5002e-02, 2.9632e-02, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [3.2952e-02, 0.0000e+00, 0.0000e+00, ..., 4.5850e-02,\n", - " 2.0439e-02, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 4.4258e-02],\n", - " [0.0000e+00, 0.0000e+00, 2.5565e-02, ..., 0.0000e+00,\n", - " 9.0044e-03, 4.9084e-02]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.1141e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.3697e-01, 3.8527e-01, 3.2900e-01, ..., 2.8704e-01,\n", - " 2.3351e-01, 1.9004e-01],\n", - " [1.3575e-01, 3.5783e-01, 3.3573e-01, ..., 2.2082e-01,\n", - " 1.5855e-01, 1.3587e-01],\n", - " [2.1929e-01, 2.8900e-01, 2.8255e-01, ..., 2.0603e-01,\n", - " 2.3927e-01, 2.1909e-01],\n", - " ...,\n", - " [2.3292e-01, 3.9097e-01, 3.6399e-01, ..., 2.0598e-01,\n", - " 2.5374e-01, 2.3137e-01],\n", - " [1.8739e-01, 3.0794e-01, 3.0297e-01, ..., 2.7251e-01,\n", - " 2.5192e-01, 2.0837e-01],\n", - " [2.2454e-01, 4.1402e-01, 5.4083e-01, ..., 3.1875e-01,\n", - " 2.5080e-01, 2.5939e-01]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.6457e-01, 4.9519e-01, 5.6702e-01, ..., 3.0955e-01,\n", - " 3.5292e-01, 3.2669e-01],\n", - " [2.1577e-01, 5.1833e-01, 4.9183e-01, ..., 3.6043e-01,\n", - " 3.8524e-01, 3.6155e-01],\n", - " [2.0068e-01, 4.2784e-01, 5.2818e-01, ..., 3.1871e-01,\n", - " 3.2452e-01, 3.1036e-01],\n", - " ...,\n", - " [4.9855e-01, 5.1001e-01, 5.2279e-01, ..., 3.6450e-01,\n", - " 3.4338e-01, 3.3603e-01],\n", - " [4.1233e-01, 5.5518e-01, 5.2828e-01, ..., 4.0676e-01,\n", - " 3.3873e-01, 3.6724e-01],\n", - " [4.0820e-01, 4.6187e-01, 4.7338e-01, ..., 3.8691e-01,\n", - " 3.6039e-01, 3.8022e-01]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 5.7852e-03, 0.0000e+00, ..., 7.4838e-03,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0351e-02,\n", - " 0.0000e+00, 2.6720e-04],\n", - " [9.4807e-04, 0.0000e+00, 0.0000e+00, ..., 7.9551e-03,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [2.0326e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 1.0801e-02, 0.0000e+00],\n", - " [1.8470e-01, 0.0000e+00, 0.0000e+00, ..., 5.0584e-02,\n", - " 9.4758e-02, 5.9146e-02]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[3.8708e-01, 2.8022e-01, 3.5893e-01, ..., 1.6595e-01,\n", - " 1.6031e-01, 2.1136e-01],\n", - " [1.5595e-01, 3.0544e-01, 2.4666e-01, ..., 2.2675e-01,\n", - " 2.5765e-01, 1.9682e-01],\n", - " [2.9518e-01, 4.1210e-01, 2.0063e-01, ..., 1.7595e-01,\n", - " 2.2537e-01, 2.2214e-01],\n", - " ...,\n", - " [2.4745e-01, 2.6259e-01, 3.8654e-01, ..., 2.3620e-01,\n", - " 2.3157e-01, 1.8514e-01],\n", - " [2.5715e-01, 2.9593e-01, 4.7745e-01, ..., 2.3546e-01,\n", - " 2.5073e-01, 2.0976e-01],\n", - " [1.2015e+00, 8.4644e-01, 7.3386e-01, ..., 1.0252e+00,\n", - " 9.5310e-01, 1.0013e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[4.5013e-01, 4.7484e-01, 4.0540e-01, ..., 1.9346e-01,\n", - " 1.7826e-01, 1.4777e-01],\n", - " [4.7546e-01, 4.8187e-01, 3.6760e-01, ..., 2.7809e-01,\n", - " 3.2997e-01, 3.2337e-01],\n", - " [4.6160e-01, 4.0050e-01, 3.9061e-01, ..., 3.6613e-01,\n", - " 3.5243e-01, 2.9739e-01],\n", - " ...,\n", - " [5.5148e-01, 5.1018e-01, 4.0132e-01, ..., 3.8948e-01,\n", - " 3.5737e-01, 3.3088e-01],\n", - " [4.1973e-01, 4.5475e-01, 4.5320e-01, ..., 3.8343e-01,\n", - " 4.0126e-01, 3.6181e-01],\n", - " [3.4280e-01, 3.1606e-01, 4.4701e-01, ..., 2.1665e-01,\n", - " 2.3985e-01, 2.3903e-01]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [4.1783e-02, 0.0000e+00, 1.5805e-02, ..., 0.0000e+00,\n", - " 2.2508e-02, 0.0000e+00],\n", - " [4.3234e-02, 7.7864e-02, 0.0000e+00, ..., 1.6347e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.2092e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.3563e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[0.0000e+00, 2.5187e-01, 2.4979e-01, ..., 2.4775e-01,\n", - " 2.2354e-01, 1.9149e-01],\n", - " [1.6541e-01, 1.9586e-01, 1.9813e-01, ..., 2.7344e-01,\n", - " 2.0928e-01, 2.6150e-01],\n", - " [1.0495e-01, 6.3299e-02, 3.3844e-01, ..., 2.5138e-01,\n", - " 1.2470e-01, 2.3927e-01],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.1428e-01, 4.5667e-01, 4.6821e-01, ..., 3.2058e-01,\n", - " 3.3579e-01, 3.9013e-01],\n", - " [1.0441e-01, 4.5739e-01, 4.6107e-01, ..., 3.8468e-01,\n", - " 3.8291e-01, 3.6686e-01],\n", - " [1.9868e-01, 3.5520e-01, 4.4313e-01, ..., 4.0679e-01,\n", - " 3.8068e-01, 3.0646e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.4654e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 3.3902e-02],\n", - " [0.0000e+00, 0.0000e+00, 1.8307e-02, ..., 5.1669e-02,\n", - " 9.4838e-03, 7.4535e-02],\n", - " [9.9215e-02, 0.0000e+00, 1.5872e-02, ..., 1.6203e-02,\n", - " 5.1401e-02, 1.9239e-03],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[4.0034e-01, 2.5306e-01, 2.0218e-01, ..., 9.8162e-02,\n", - " 7.0643e-02, 4.9741e-02],\n", - " [1.2568e-01, 2.1031e-01, 1.1182e-01, ..., 4.2781e-02,\n", - " 1.1969e-01, 1.2005e-01],\n", - " [2.8787e-01, 2.4031e-01, 2.2566e-01, ..., 0.0000e+00,\n", - " 6.4181e-02, 5.8730e-02],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.8405e-01, 3.0990e-01, 3.7156e-01, ..., 1.8125e-01,\n", - " 1.5051e-01, 1.9620e-01],\n", - " [4.7286e-01, 4.0529e-01, 3.9718e-01, ..., 2.4710e-01,\n", - " 4.5657e-02, 1.1501e-01],\n", - " [3.2621e-01, 3.0073e-01, 3.0477e-01, ..., 2.3529e-01,\n", - " 2.1357e-01, 1.6986e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]],\n", - "\n", - "\n", - " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[3.3438e-02, 1.2378e-03, 5.2972e-02, ..., 7.2712e-02,\n", - " 8.6563e-02, 1.4494e-01],\n", - " [1.1043e-01, 6.1431e-02, 6.3630e-02, ..., 8.1278e-02,\n", - " 6.2590e-02, 8.3154e-02],\n", - " [1.7677e-02, 2.0111e-03, 7.8750e-02, ..., 6.9633e-02,\n", - " 8.9799e-02, 5.3263e-02],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[1.0034e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.5627e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [5.1447e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 4.3641e-03],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " ...,\n", - "\n", - " [[2.5142e-01, 4.5964e-01, 3.7346e-01, ..., 4.7631e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.9760e-01, 2.6627e-01, 1.1191e-01, ..., 3.0450e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [1.6341e-01, 3.2938e-01, 2.5690e-01, ..., 5.5694e-02,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00],\n", - " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", - " 1.0094e+00, 1.0221e+00]],\n", - "\n", - " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 2.2189e-02, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 2.8490e-02],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " ...,\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00],\n", - " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", - " 0.0000e+00, 0.0000e+00]],\n", - "\n", - " [[2.5810e-01, 6.3017e-01, 3.7038e-01, ..., 1.8704e-01,\n", - " 8.2694e-02, 9.9127e-02],\n", - " [1.7293e-01, 5.0679e-01, 4.0739e-01, ..., 1.6006e-01,\n", - " 1.1725e-01, 9.9405e-02],\n", - " [2.4175e-01, 4.1616e-01, 4.1257e-01, ..., 1.3520e-01,\n", - " 7.9126e-02, 1.2846e-01],\n", - " ...,\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00],\n", - " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", - " 1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n", - " grad_fn=)\n" - ] - } - ], - "source": [ - "xs = model.encoder.global_cmvn(feat)\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", - "\n", - "x = xs.unsqueeze(1)\n", - "x = model.encoder.embed.conv(x)\n", - "print(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "9a9478ad", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.03426375 0.14291267 -0.06718873 ... 0.09064753 0.01809387\n", - " -0.0434088 ]\n", - " [-0.05007839 0.11054724 -0.10399298 ... 0.11457238 0.04244684\n", - " -0.01249714]\n", - " [-0.10695291 0.16910909 -0.08352133 ... 0.07710276 0.01168563\n", - " -0.03584499]\n", - " ...\n", - " [-0.06060536 0.14455931 -0.05470302 ... 0.05364908 0.03033342\n", - " -0.02610814]\n", - " [-0.08505894 0.13611752 -0.11132983 ... 0.13079923 0.01580139\n", - " -0.02281028]\n", - " [-0.10604677 0.14714901 -0.10885533 ... 0.08543444 0.03719445\n", - " -0.04634233]]\n", - "\n", - " [[-0.12392755 0.14486063 -0.05674079 ... 0.02573164 0.03128851\n", - " 0.00545091]\n", - " [-0.04775286 0.08473608 -0.08507854 ... 0.04573154 0.04240163\n", - " 0.01053247]\n", - " [-0.05940291 0.10023535 -0.0814373 ... 0.035965 0.01673085\n", - " 0.02089563]\n", - " ...\n", - " [-0.09222981 0.15823206 -0.07700447 ... 0.08122957 0.03136991\n", - " -0.00646474]\n", - " [-0.07331756 0.14482647 -0.07838815 ... 0.1086944 0.01356864\n", - " -0.02777974]\n", - " [-0.07937264 0.20143102 -0.05544947 ... 0.10287814 0.00608235\n", - " -0.0479918 ]]\n", - "\n", - " [[-0.03670349 0.0893159 -0.08718812 ... 0.0131405 0.00642052\n", - " 0.00573716]\n", - " [ 0.01089254 0.11146393 -0.10263617 ... 0.05070438 0.01960694\n", - " 0.03521532]\n", - " [-0.0218228 0.11443964 -0.06678198 ... 0.04327708 0.00861394\n", - " 0.02871092]\n", - " ...\n", - " [-0.06792898 0.14376275 -0.07899005 ... 0.11248926 0.03208683\n", - " -0.0326424 ]\n", - " [-0.07884051 0.17024788 -0.08583611 ... 0.09028331 0.03588808\n", - " -0.0207509 ]\n", - " [-0.13792302 0.27163863 -0.23930418 ... 0.13391261 0.0752104\n", - " -0.08621951]]\n", - "\n", - " ...\n", - "\n", - " [[-0.02446348 0.11595841 -0.03591986 ... 0.0628897 0.02895011\n", - " -0.06532725]\n", - " [-0.05378424 0.1260737 -0.09023033 ... 0.09078894 0.01035743\n", - " 0.03701983]\n", - " [-0.04566649 0.14275314 -0.0668687 ... 0.09890588 -0.00612222\n", - " 0.03439377]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]\n", - "\n", - " [[-0.01012144 0.03909408 -0.07077143 ... 0.00452683 -0.01377654\n", - " 0.02897627]\n", - " [-0.00519154 0.03594019 -0.06831125 ... 0.05693541 -0.00406374\n", - " 0.0456164 ]\n", - " [-0.01762631 0.00500899 -0.05886075 ... 0.02112178 -0.00729015\n", - " 0.02782153]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]\n", - "\n", - " [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402 0.04296734\n", - " 0.06151697]\n", - " [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064 0.05696633\n", - " 0.03304394]\n", - " [-0.01818341 -0.0117864 -0.09679577 ... -0.00870231 0.00362198\n", - " 0.01916483]\n", - " ...\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]\n", - " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", - " -0.18293698]]]\n", - "torch.Size([16, 51, 256])\n" - ] - } - ], - "source": [ - "b, c, t, f = x.size()\n", - "x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))\n", - "print(x.cpu().detach().numpy())\n", - "print(x.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "fd69003f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", - " -0.6945408 ]\n", - " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", - " -0.1999542 ]\n", - " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", - " -0.5735198 ]\n", - " ...\n", - " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", - " -0.41773027]\n", - " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", - " -0.36496443]\n", - " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", - " -0.74147725]]\n", - "\n", - " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", - " 0.08721463]\n", - " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", - " 0.16851945]\n", - " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", - " 0.33433008]\n", - " ...\n", - " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", - " -0.10343577]\n", - " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", - " -0.44447583]\n", - " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", - " -0.7678688 ]]\n", - "\n", - " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", - " 0.09179455]\n", - " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", - " 0.56344515]\n", - " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", - " 0.45937473]\n", - " ...\n", - " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", - " -0.52227837]\n", - " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", - " -0.33201432]\n", - " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", - " -1.3795122 ]]\n", - "\n", - " ...\n", - "\n", - " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", - " -1.045236 ]\n", - " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", - " 0.5923172 ]\n", - " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", - " 0.55030036]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", - " 0.46362036]\n", - " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", - " 0.72986233]\n", - " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", - " 0.44514441]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]\n", - "\n", - " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", - " 0.9842716 ]\n", - " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", - " 0.52870303]\n", - " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", - " 0.30663735]\n", - " ...\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]\n", - " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", - " -2.9269917 ]]]\n" - ] - } - ], - "source": [ - "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", - "print(x.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "8ed88489", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.float32\n", - "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", - " 0.0000000e+00 1.0000000e+00]\n", - " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", - " 1.0746076e-04 1.0000000e+00]\n", - " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", - " 2.1492151e-04 1.0000000e+00]\n", - " ...\n", - " [-7.6825464e-01 -6.4014435e-01 6.3279724e-01 ... 9.9998462e-01\n", - " 5.1580933e-03 9.9998671e-01]\n", - " [-9.5375264e-01 3.0059254e-01 9.9899054e-01 ... 9.9998397e-01\n", - " 5.2655530e-03 9.9998611e-01]\n", - " [-2.6237485e-01 9.6496606e-01 5.6074661e-01 ... 9.9998331e-01\n", - " 5.3730118e-03 9.9998558e-01]]]\n" - ] - } - ], - "source": [ - "print(pos_emb.dtype)\n", - "print(pos_emb.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "id": "5e277881", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 51, 256])\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'mask' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mpos_emb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpos_emb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 143\u001b[0;31m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 144\u001b[0m \u001b[0mx_att\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 145\u001b[0m )\n", - "\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined" - ] - } - ], - "source": [ - "def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n", - " use_dynamic_chunk: bool,\n", - " use_dynamic_left_chunk: bool,\n", - " decoding_chunk_size: int, static_chunk_size: int,\n", - " num_decoding_left_chunks: int):\n", - " \"\"\" Apply optional mask for encoder.\n", - " Args:\n", - " xs (torch.Tensor): padded input, (B, L, D), L for max length\n", - " mask (torch.Tensor): mask for xs, (B, 1, L)\n", - " use_dynamic_chunk (bool): whether to use dynamic chunk or not\n", - " use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n", - " training.\n", - " decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n", - " 0: default for training, use random dynamic chunk.\n", - " <0: for decoding, use full chunk.\n", - " >0: for decoding, use fixed chunk size as set.\n", - " static_chunk_size (int): chunk size for static chunk training/decoding\n", - " if it's greater than 0, if use_dynamic_chunk is true,\n", - " this parameter will be ignored\n", - " num_decoding_left_chunks: number of left chunks, this is for decoding,\n", - " the chunk size is decoding_chunk_size.\n", - " >=0: use num_decoding_left_chunks\n", - " <0: use all left chunks\n", - " Returns:\n", - " torch.Tensor: chunk mask of the input xs.\n", - " \"\"\"\n", - " # Whether to use chunk mask or not\n", - " if use_dynamic_chunk:\n", - " max_len = xs.size(1)\n", - " if decoding_chunk_size < 0:\n", - " chunk_size = max_len\n", - " num_left_chunks = -1\n", - " elif decoding_chunk_size > 0:\n", - " chunk_size = decoding_chunk_size\n", - " num_left_chunks = num_decoding_left_chunks\n", - " else:\n", - " # chunk size is either [1, 25] or full context(max_len).\n", - " # Since we use 4 times subsampling and allow up to 1s(100 frames)\n", - " # delay, the maximum frame is 100 / 4 = 25.\n", - " chunk_size = torch.randint(1, max_len, (1, )).item()\n", - " num_left_chunks = -1\n", - " if chunk_size > max_len // 2:\n", - " chunk_size = max_len\n", - " else:\n", - " chunk_size = chunk_size % 25 + 1\n", - " if use_dynamic_left_chunk:\n", - " max_left_chunks = (max_len - 1) // chunk_size\n", - " num_left_chunks = torch.randint(0, max_left_chunks,\n", - " (1, )).item()\n", - " chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n", - " num_left_chunks,\n", - " xs.device) # (L, L)\n", - " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", - " chunk_masks = masks & chunk_masks # (B, L, L)\n", - " elif static_chunk_size > 0:\n", - " num_left_chunks = num_decoding_left_chunks\n", - " chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n", - " num_left_chunks,\n", - " xs.device) # (L, L)\n", - " chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)\n", - " chunk_masks = masks & chunk_masks # (B, L, L)\n", - " else:\n", - " chunk_masks = masks\n", - " return chunk_masks\n", - "\n", - "from wenet.utils.mask import make_pad_mask\n", - "\n", - "\n", - "masks = ~make_pad_mask(feat_len).unsqueeze(1)\n", - "xs = model.encoder.global_cmvn(feat)\n", - "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", - "\n", - "mask_pad = masks\n", - "decoding_chunk_size=0\n", - "num_decoding_left_chunks=-1\n", - "use_dynamic_left_chunk=-1\n", - "use_dynamic_chunk=False\n", - "static_chunk_size=-1\n", - "chunk_masks = add_optional_chunk_mask(\n", - " xs, \n", - " masks, \n", - " use_dynamic_chunk,\n", - " use_dynamic_left_chunk,\n", - " decoding_chunk_size, \n", - " static_chunk_size,\n", - " num_decoding_left_chunks)\n", - "\n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', \n", - " embed_out=xs.cpu().detach().numpy(), \n", - " pos_emb=pos_emb.cpu().detach().numpy(),\n", - " chunk_masks=chunk_masks.cpu().detach().numpy(),\n", - " mask_pad=mask_pad.cpu().detach().numpy())\n", - "\n", - "model.eval()\n", - "# print(chunk_masks)\n", - "print(xs.shape)\n", - "for layer in model.encoder.encoders:\n", - " #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())\n", - " \n", - " x = xs\n", - " residual = x\n", - " x_norm = layer.norm_ff_macaron(x)\n", - " !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', \n", - " norm_ff=x_norm.cpu().detach().numpy(),\n", - " xs=xs.cpu().detach().numpy())\n", - " #print(x.cpu().detach().numpy())\n", - " for p in layer.norm_ff_macaron.parameters():\n", - " #print(p, p.sum())\n", - " pass\n", - " \n", - " x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n", - " \n", - " ps = []\n", - " for n, p in layer.feed_forward_macaron.state_dict().items():\n", - " #print(n, p.cpu().data.numpy())\n", - " ps.append(p.cpu().data.numpy())\n", - " pass\n", - "\n", - " ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n", - " ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n", - " ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', \n", - " norm_ff=x_norm.cpu().detach().numpy(),\n", - " ff_out=x.cpu().detach().numpy(),\n", - " ff_l_x = ff_l_x.cpu().detach().numpy(),\n", - " ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n", - " ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n", - " ps=ps,\n", - " )\n", - " \n", - " \n", - " residual = x\n", - " x = layer.norm_mha(x)\n", - " x_q = x\n", - " \n", - " x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', \n", - " x_q=x_q.cpu().detach().numpy(),\n", - " x=x.cpu().detach().numpy(),\n", - " pos_emb = pos_emb.cpu().detach().numpy(),\n", - " mask=mask.cpu().detach().numpy(),\n", - " x_att=x_att.cpu().detach().numpy(),\n", - " )\n", - " \n", - " break\n", - "#print(xs.cpu().detach().numpy())\n", - "\n", - "\n", - "i = 0\n", - "for layer in model.encoder.encoders:\n", - " xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n", - " i += 1\n", - " if i == 2:\n", - " np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())\n", - " \n", - "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c43fd4f1", - "metadata": {}, - "outputs": [], - "source": [ - "out, mask = model.encoder(feat, feat_len)\n", - "#print(out.cpu().detach().numpy())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e73db22", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f506114", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}