From dc0c5351ecc4f1d06195d415477fa231561c1a00 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 31 May 2021 07:45:43 +0000 Subject: [PATCH] add test notebook --- .notebook/WarmupLR.ipynb | 339 ++ ...l.ipynb => u2_confermer_model_wenet.ipynb} | 0 .notebook/u2_tansformer_model_espnet.ipynb | 1672 ++++++ .notebook/wenet_model.ipynb | 5015 +++++++++++++++++ 4 files changed, 7026 insertions(+) create mode 100644 .notebook/WarmupLR.ipynb rename .notebook/{u2_model.ipynb => u2_confermer_model_wenet.ipynb} (100%) create mode 100644 .notebook/u2_tansformer_model_espnet.ipynb create mode 100644 .notebook/wenet_model.ipynb diff --git a/.notebook/WarmupLR.ipynb b/.notebook/WarmupLR.ipynb new file mode 100644 index 00000000..21abf9cb --- /dev/null +++ b/.notebook/WarmupLR.ipynb @@ -0,0 +1,339 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d6a0e098", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Union\n", + "\n", + "import torch\n", + "from torch.optim.lr_scheduler import _LRScheduler\n", + "\n", + "from typeguard import check_argument_types\n", + "\n", + "\n", + "class WarmupLR(_LRScheduler):\n", + " \"\"\"The WarmupLR scheduler\n", + " This scheduler is almost same as NoamLR Scheduler except for following\n", + " difference:\n", + " NoamLR:\n", + " lr = optimizer.lr * model_size ** -0.5\n", + " * min(step ** -0.5, step * warmup_step ** -1.5)\n", + " WarmupLR:\n", + " lr = optimizer.lr * warmup_step ** 0.5\n", + " * min(step ** -0.5, step * warmup_step ** -1.5)\n", + " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " optimizer: torch.optim.Optimizer,\n", + " warmup_steps: Union[int, float] = 25000,\n", + " last_epoch: int = -1,\n", + " ):\n", + " assert check_argument_types()\n", + " self.warmup_steps = warmup_steps\n", + "\n", + " # __init__() must be invoked before setting field\n", + " # because step() is also invoked in __init__()\n", + " super().__init__(optimizer, last_epoch)\n", + "\n", + " def __repr__(self):\n", + " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", + "\n", + " def get_lr(self):\n", + " step_num = self.last_epoch + 1\n", + " return [\n", + " lr\n", + " * self.warmup_steps ** 0.5\n", + " * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)\n", + " for lr in self.base_lrs\n", + " ]\n", + "\n", + " def set_step(self, step: int):\n", + " self.last_epoch = step" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0d496677", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "model = torch.nn.Linear(10, 200)\n", + "optimizer = optim.Adam(model.parameters())\n", + "scheduler = WarmupLR(optimizer, warmup_steps=25000)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e3e3f3dc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0.0 -1\n" + ] + } + ], + "source": [ + "infos = {}\n", + "start_epoch = infos.get('epoch', -1) + 1\n", + "cv_loss = infos.get('cv_loss', 0.0)\n", + "step = infos.get('step', -1)\n", + "print(start_epoch, cv_loss, step)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dc3d550c", + "metadata": {}, + "outputs": [], + "source": [ + "scheduler.set_step(step)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e527634e", + "metadata": {}, + "outputs": [], + "source": [ + "lrs=[]\n", + "for i in range(100000):\n", + " 
scheduler.step()\n", + " lrs.append(scheduler.get_lr())" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f1452db9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting matplotlib\n", + " Downloading matplotlib-3.4.1-cp38-cp38-manylinux1_x86_64.whl (10.3 MB)\n", + "\u001b[K |████████████████████████████████| 10.3 MB 575 kB/s eta 0:00:01\n", + "\u001b[?25hCollecting kiwisolver>=1.0.1\n", + " Downloading kiwisolver-1.3.1-cp38-cp38-manylinux1_x86_64.whl (1.2 MB)\n", + "\u001b[K |████████████████████████████████| 1.2 MB 465 kB/s eta 0:00:01\n", + "\u001b[?25hRequirement already satisfied: pillow>=6.2.0 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (8.1.2)\n", + "Requirement already satisfied: numpy>=1.16 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (1.20.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.8.1)\n", + "Collecting cycler>=0.10\n", + " Downloading cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)\n", + "Requirement already satisfied: pyparsing>=2.2.1 in /workspace/wenet/venv/lib/python3.8/site-packages (from matplotlib) (2.4.7)\n", + "Requirement already satisfied: six in /workspace/wenet/venv/lib/python3.8/site-packages (from cycler>=0.10->matplotlib) (1.15.0)\n", + "Installing collected packages: kiwisolver, cycler, matplotlib\n", + "Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1\n" + ] + } + ], + "source": [ + "!pip install matplotlib\n", + "import matplotlib.pyplot as plt\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0f36d04f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqc0lEQVR4nO3deXxV1b338c8vCUkYkkAghJAEAhLQIJMEHHFCBa2KVkG0T7Wt1qet9ra1w9Xn3ufe1ld7b21tvVq1alut+mhJQK3Yqjig1SpCDgIyBiLTSZhCAglTyLSeP86GxjTDQZKc6ft+vXh5zjrrrLM2O+bL3mvv3zHnHCIiIu2JC/UEREQkvCkoRESkQwoKERHpkIJCREQ6pKAQEZEOJYR6Al1h0KBBLi8vL9TTEBGJKMuXL9/rnMvorF9UBEVeXh4+ny/U0xARiShmti2Yfjr1JCIiHVJQiIhIhxQUIiLSIQWFiIh0SEEhIiIdCioozGymmZWaWZmZ3d3G60lmVuS9vtTM8lq8do/XXmpmM1q0P2lme8xsTaux0s3sTTPb5P13wElsn4iInKROg8LM4oFHgMuBAuBGMyto1e1WYJ9zbhTwAHCf994CYC4wFpgJPOqNB/BHr621u4G3nXP5wNvecxERCZFgjiimAmXOuc3OuXpgHjCrVZ9ZwNPe4wXAdDMzr32ec+6oc24LUOaNh3PuPaC6jc9rOdbTwDXBb450p82VB3m3dE+opyEiPSyYoMgG/C2el3ttbfZxzjUCNcDAIN/bWqZzbqf3eBeQ2VYnM7vdzHxm5qusrAxiM+Rk3fS7pXzlqRLeWrc71FMRkR4U1ovZLvCtSm1+s5Jz7gnnXKFzrjAjo9M70OUkle05wK7aOgC+V7SSTysPhnhGItJTggmKCiC3xfMcr63NPmaWAKQBVUG+t7XdZpbljZUF6FxHGCj2lZMQZ7xy53kkJsRx+zM+DtQ1hHpaItIDggmKEiDfzEaYWSKBxemFrfosBG7xHl8PLPaOBhYCc72rokYA+cCyTj6v5Vi3AC8HMUfpRg1Nzbz4cTnTTxvMuJw0Hr7pDLZWHeZ7RatobtZX6YpEu06DwltzuBNYBKwHip1za83sXjO72uv2B2CgmZUBd+FdqeScWwsUA+uA14E7nHNNAGb2J2AJMMbMys3sVm+snwOXmtkm4BLvuYTQ4g172HuwnjmFgYPDs08ZyL9/4TTeWr+bB9/eFOLZiUh3s8A//CNbYWGhU/XY7nPb0yV8Ul7Dh3dfTEJ84N8Wzjl+uOATFiwv58G5E5k1sbNrFEQk3JjZcudcYWf9wnoxW0JvT20d75RWct3knOMhAWBm/Oza0zlzRDo/nP8JJVvbutJZRKKBgkI6tODjcpqa3fHTTi0lJcTz+Jcnk5Pem68/42PL3kMhmKGIdDcFhbTLOcd8XzlT89IZMahvm33690nkqa9MIc6Mrz61jOpD9T08SxHpbgoKaVfJ1n1s2XuIOVP++WiipeED+/K7myezo6aOrz/j40h9Uw/NUER6goJC2lXs89MvKYErxg3ptO/k4ek8eMNEVmzfxzefW059Y3MPzFBEeoKCQtp0oK6Bv36yk6smZNEnMbivVr98XBY/u3Yc75ZW8oP5usdCJFoE9xtAYs5fP9nJkYamNhexO3Lj1GHsP9zAfa9vIK13L+6dNZZAfUgRiVQKCmlTkc9P/uB+TMztf8Lv/eaFp7D/cD2Pv7eZ/n168f3LxnT9BEWkxygo5J9s2n2AFdv38+9fOO1zHw3cffmp1Bxp4DeLy0iMj+Pb0/O7eJYi0lMUFPJPin1+EuKMayZ9/rutAzfkjaO+sZlfvbkRM7jzYoWFSCRSUMhnBAoAVnDJaZkM6pd0UmPFxxm/nD0BgPvf2AgoLEQikYJCPuPt9XuoOlTPnCk5XTJe67AwM+64aFSXjC0iPUNBIZ8x3+cnMzWJ8/O77sugjoWFA365qJT6xma+e0m+roYSiRAKCjlud20d75Tu4RsXnPKZAoBdIT7OuH/2BBLijAff3kTNkQb+48oC4uIUFiLhTkEhxy1YXk6z44TvnQhWfJxx33XjSUnuxZMfbOFAXSP3XTeuy0NJRLqWgkKAYwUA/UwdkU5eOwUAu0JcnPF/rzyNtN69eOCtjRyoa+A3N00iKSG+2z5TRE6O/iknACzbUs3WqsPc0E1HEy2ZGd+5JJ//vKqAN9bt5qtPlVCr798WCVsKCgGg2FfuFQDM6rHP/Oq5I/j1nAks21LN9b/9kIr9R3rss0UkeAoK4UBdA6+u3slVE4bSO7FnTwF98Ywcnv7aVHbur+PaRz5gTUVNj36+iHROQSH8xSsAeEMn3zvRXc4dNYgF3zyHXvFxzHl8Ce9s2BOSeYhI2xQUQlGJn9GZ/ZiQkxayOYwZksJL3zqHkRl9ufXpEp5ZshXnVKZcJBwoKGLcxt0HWOnfz5zC3JDfADc4NZmi28/mojGD+Y+X13LPi6s52qhvyxMJNQVFjCsu8dMr3rj2JAoAdqW+SQk8cXMhd1x0CvNK/Nz0u6Xsqa0L9bREYpqCIobVNzbz0opAAcCBJ1kAsCvFxxk/nHEqj9x0But21HLVw39npX9/qKclErMUFDFs8YbdgQKAPXDvxOfxhfFZvPDNc0iICyxyF5f4Qz0lkZikoIhhxb5yhqQmc/7orisA2NUKhqbyyrfPo3D4AH70wid8v3gVh+sbQz0tkZiioIhRu2rqeLd0D9dNziY+zAvzpfdN5Nlbz+Rfpufz4opyZj38AWV7DoR6WiIxQ0ERo174OFAAcPbk8Dzt1Fp8nHHXpaN55mtTqT5Uz1W/+YCXVpSHeloiMUFBEYOccxT7/JzZzQUAu8O0/Axe/c40xuWk8b2iVfxowSoOHdWpKJHupKCIQUu3VLOt6nDI7sQ+WZmpyTx/25ncedEo5i8v54qH3mf5tn2hnpZI1FJQxKBin5+UpAQuP73nCgB2tYT4OH4wYwxFt59NY5Nj9mMf8us3N9LQ1BzqqYlEnaCCwsxmmlmpmZWZ2d1tvJ5kZkXe60vNLK/Fa/d47aVmNqOzMc1supl9bGYrzezvZqYvWO5CtccKAE7s+QKA3WHqiHRe++40rpmUzUNvb2L2Y0vYsvdQqKclElU6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5m+BLznnJgLPA/9+Ulson/GXVTupa2juke+d6Cmpyb349ZyJPHzTJLbsPcQVD77PUx9soblZtaJEukIwRxRTgTLn3GbnXD0wD5jVqs8s4Gnv8QJgugUKB80C5jnnjjrntgBl3ngdjemAVO9xGrDj822atKXI52dMZgrjQ1gAsLtcOX4or393GmeOTOcnr6xjzuNL+LTyYKinJRLxggmKbKDlLbHlXlubfZxzjUANMLCD93Y05m3Aq2ZWDnwZ+HlbkzKz283MZ2a+ysrKIDZDSncdYJV/P3OmhL4AYHfJSuvNU1+Zwq9mT2DTnoNc/uD7/PbdT2nU2oXI5xaOi9nfA65wzuUATwG/bquTc+
4J51yhc64wIyN87ywOJ8W+8CoA2F3MjOsm5/DmXedz0ZgM7nt9A9c++iHrd9aGemoiESmYoKgAWp7QzvHa2uxjZgkEThlVdfDeNtvNLAOY4Jxb6rUXAecEtSXSoWMFAC8tyCS9b2Kop9MjBqck89j/mswjN53Bjv1HuPI3f+e/Xl2v+y5ETlAwQVEC5JvZCDNLJLA4vbBVn4XALd7j64HFLvCtMwuBud5VUSOAfGBZB2PuA9LMbLQ31qXA+s+/eXLM2+t3U32ontlRtIgdDDPjC+OzeOuuC5hTmMMT721m+q/+xmurd+qLkUSClNBZB+dco5ndCSwC4oEnnXNrzexewOecWwj8AXjWzMqAagK/+PH6FQPrgEbgDudcE0BbY3rtXwdeMLNmAsHxtS7d4hhV5PMHCgDmx+ZpugF9E/nvL45ndmEu//bSGr753MdcMDqDn1w9NuLuThfpaRYN/6oqLCx0Pp8v1NMIWztrjnDuzxfzrQtH8YMZY0I9nZBrbGrm2Y+28as3NlLf1Mw3LjiFb1wwkj6Jnf67SSSqmNly51xhZ/3CcTFbutgLy70CgIU5oZ5KWEiIj+Or547g7e9fwMyxQ3jo7U1cfP/fePHjct17IdIGBUWUa252FPvKOWtkOsMH6hRLS5mpyTx04yQWfONsMlOTuKt4Fdc8+gG+rdWhnppIWFFQRLmlW6rZXh25BQB7QmFeOi9961weuGECe2qPcv1jS7jj+Y/xVx8O9dREwoJOyka5+T4/KcmRXQCwJ8TFGddOymHG2CE8/rfNPP7ep7y5djdfOmsYd1w0ikFh9J3iIj1NRxRRrLaugVfX7OTqCUNJ7hX5BQB7Qp/EBL536Wje+cGFXDspm6c/3MoFv3iHX7+5kQN1DaGenkhIKCii2CurdgQKAOq00wnLSuvNfdeP543vXcAFYzJ46O1NnP+Ld/j9+5upa2gK9fREepSCIooVl/g5dUgK47KjrwBgTxk1uB+PfmkyC+88l9Oz0/jpX9dz0f3v8vzS7dQ3qn6UxAYFRZTasKuWVeU1zCmM3gKAPWl8Tn+evfVMnr/tTDJTk/k/L63mwl++w7NLtnK0UUcYEt0UFFGquKScXvHGNVFeALCnnTNqEC996xye+dpUsvr35v++vJYLfvEuf/xgi05JSdRSUEShQAHAci4rGBIzBQB7kplx/ugMFnzjbJ677UyGpffhx6+sY5q3hnG4XkUHJbro8tgo9Nb63ew73KA7sbuZmXHuqEGcO2oQH22u4qG3N/HTv67n4XfKuPms4dx8Tp4uq5WooKCIQkUlfrLSkpkWowUAQ+GskQM5a+RAlm+r5rG/beahxWU8/t5mrp+cw9enjVThQYloCooos2P/Ed7bVMmdF40iPk6L2D1t8vB0fndzOmV7DvL79zcz31fO88u2M3PsEG4/fySThg0I9RRFTpiCIsq8sLwc52D2ZN07EUqjBvfj59eN565LR/PHD7fy/z7axmtrdjE1L51bzsnjsrGZ9IrXEqFEBpUZjyLNzY4L73+XnAG9ef7rZ4V6OtLCwaONzFu2naeXbMVffYQhqcl8+ezhzJ2Sy0CtY0iIqMx4DPpoSxXbqw8zJ8a+xS4S9EtK4LZpI3n3Bxfxu5sLGTW4H79cVMrZP1/M94tXsbq8JtRTFGmXTj1Fkfm+clKSE5h5+pBQT0XaER9nXFqQyaUFmZTtOcDTH27jhY/LeeHjcs4Y1p8vnz2cy0/PUm0uCSs69RQlao40MPVnbzG7MIefXjMu1NORE1Bb18ACXznPLNnK1qrDpPXuxbWTsrlx6jDGDEkJ9fQkigV76klHFFHilVU7ONrYzA2Fw0I9FTlBqcm9+Np5I/jKOXl8tLmK55dt57ml2/jjh1s5Y1h/bpw6jCvHD6V3oo4yJDR0RBElrn7479Q3NvPad6aptlMUqDp4lBc/ruBPJdvZXHmIlOQErpmYzQ1Tchk7NFX7WLqEjihiyPqdtXxSXsN/XlWgXyBRYmC/JL5+/khumzaCZVuq+dOy7RT5/Dz70TbGZKZw3eRsZk3MJjM1OdRTlRigoIgCxT4/ifFxXDNRBQCjjZlx5siBnDlyID8+XM8rn+zkxY/L+a9XN/Dz1zZwXn4G152RzWUFQ3RqSrqNgiLCHW1s4s8rKrh0bCYDVAAwqvXvk8iXzxrOl88azqeVB3np4wpeWlHBd+atpF9SAl8Yl8UXz8hmSl46cborX7qQgiLCvbVuD/sON+jeiRhzSkY/fjBjDHddOpqPtlTx4scV/OWTHRT5AnW+rhyfxZXjhzI+J02nI+WkaTE7wt385DLKdh/g/X+9WLWdYtzh+kbeWLubv3yyg79trKShyTEsvQ9XTcjiqglDGZOZotCQz9BidgzYsf8I72+q5NsqAChAn8QErpmUzTWTsqk53MCitbt45ZMdPPa3zTzyzqeMGtyPq8YP5coJWZyS0S/U05UIoqCIYAuOFQDUaSdpJa1PL+ZMyWXOlFz2HjzKa2t28ZdVO/iftzfywFsbOXVICpeNHcKMsZkUZOlyW+mYTj1FqOZmxwX3v8Ow9D48d5sKAEpwdtXU8erqnby+dhe+rdU0O8hN782MgiHMOH0IZwwboKPTGKJTT1Huo81V+KuP8IPLxoR6KhJBhqQl87XzRvC180aw9+BR3lq3m0Vrd/HMkm38/u9bGNQviUsLMpl5+hDOHjmQxATVDRUFRcQq9vlJTU5gxlgVAJTPZ1C/JOZOHcbcqcM4UNfAO6WVLFq7i5dXVvCnZdtJSUrg/NEZXHTqYC4ck6GvdY1hQQWFmc0EHgTigd87537e6vUk4BlgMlAF3OCc2+q9dg9wK9AE/ItzblFHY1rgZOlPgdnee37rnHvo5DYzutQcaeC1NbuYU5irKqPSJVKSe3H1hKFcPWEodQ1NfFC2lzfW7uad0j38dfVOzGBCTn+mnzqYi04drDIiMabToDCzeOAR4FKgHCgxs4XOuXUtut0K7HPOjTKzucB9wA1mVgDMBcYCQ4G3zGy09572xvwKkAuc6pxrNrPBXbGh0WThsQKAU7SILV0vuVc800/LZPppmTQ3O9btrOXt9XtYXLqHX725kV+9uZEhqclcdGoGF5+aybmjBtInUScnolkwe3cqUOac2wxgZvOAWUDLoJgF/Nh7vAB42DsymAXMc84dBbaYWZk3Hh2M+U3gJudcM4Bzbs/n37zoVFzi57SsVMYOTQ31VCTKxcUZp2encXp2Gt+5JJ89B+p4t7SSdzbsYeHKHfxpmZ/EhDim5qUzLX8Q54/O4NQhul8j2gQTFNmAv8XzcuDM9vo45xrNrAYY6LV/1Oq9xwoStTfmKQSORq4FKgmcrtrUelJmdjtwO8CwYbFTWnvdjlpWV9TwYxUAlBAYnJLMnMJc5hTmUt/YTMnWahZv2MP7myr579c28N+vbSAjJYlpowYxbfQgzhuVQUaK1jYiXTgeLyYBdc65QjP7IvAkMK11J+fcE8ATELg8tmenGDrHCgDOUgFACbHEhDjOHTWIc0cNAgKX3r6/qZL3N+3l3Y2VvLiiA
oDTslI5P38Q0/IzKMwboHW1CBRMUFQQWDM4Jsdra6tPuZklAGkEFrU7em977eXAi97jl4CngphjTDja2MSfV1ZwmQoAShgakpbM7MJcZhfmHl/beG9TJe9v3MuTH2zh8fc2k5QQR2HeAM4eOZCzRg5kfE5/XYIbAYIJihIg38xGEPhlPhe4qVWfhcAtwBLgemCxc86Z2ULgeTP7NYHF7HxgGWAdjPln4CJgC3ABsPFzb12UeXPdbvarAKBEgJZrG9+6cBSHjjaybEs172/ay5LNVdz/RuB/69694gPBccpAzh45kHHZaSTEKzjCTadB4a053AksInAp65POubVmdi/gc84tBP4APOstVlcT+MWP16+YwCJ1I3CHc64JoK0xvY/8OfCcmX0POAjc1nWbG9mKSvxk9+99/FBfJFL0TUrgIu/SWoB9h+pZuqWKJZ9WsWRzFb94vRSAfkkJTDkeHIMoGJqqO8XDgEp4RIiK/Uc4777FfPvifO66dHTnbxCJIHsPHuWjzf8Ijs2VhwBISUrgjOEDmJI3gMK8dCbm9tcaRxdSCY8os8BXDsDsyTkhnolI1xvUL4krxw/lyvFDAdhdW8dHm6tYtqUa39Z9x09V9YoPnNKakpdO4fBAeKRrva7bKSgiQHOzY/5yP+eeMojc9D6hno5It8tMTWbWxOzjV/ftP1zP8m37KNm6D9/Wav74wVaeeG8zAKdk9A0Ehxcewwf20aXjXUxBEQGWbK6ifN8RfjhDBQAlNvXvk3j8bnGAuoYmVlfUULI1cMTx6uqdzCsJ3JqV3jeRibn9mZTbn0nDBjA+N43U5F6hnH7EU1BEABUAFPms5F7xTMlLZ0peOhA46t605yC+bdWs3L6fFf79LN4QKOpgBqMy+gXCY9gAJub2Z3RmP11ddQIUFGGu5nCgAODcKSoAKNKeuDhjzJAUxgxJ4UtnDgcCxTM/Kd/Piu37Wenfz1vrdzN/eWCtr09iPONz0piYGwiOCblpDElN1imrdigowtzCVRXUNzbr3gmRE5TWuxfT8jOYlp8BgHOObVWHWenfz4rt+1jp38/v399MY3Pgys9B/RI5PTuN8d79H+NyFB7HKCjCXJHPT0FWKqdnp4V6KiIRzczIG9SXvEF9uWZSYJG8rqGJtTtqWVNRw+qKGlaX1/Dexkq87GBQvyTGZacyLqc/47LTGJ+TRmZqcgi3IjQUFGFs7Y4a1lTU8pOrx4Z6KiJRKblXPJOHD2Dy8AHH247UN7FuZyA0VlfUsrpiP39rER4ZKUmM8446CrwqzjkDekf1kYeCIozN95WTmBDHrIlDQz0VkZjROzGeycPTmTw8/Xjb4fpG1u+s5ZPywJHHmooa3i3dczw8UpISOC0rldOyUjgtK5WCoamMzkyJmnVFBUWYqmto4qUVFcwYO4T+fXRDkUgo9UlM+KfwOFLfROnuA6zbUcv6nbWs21nLguXlHKpvAiDO4JSMfseD41iQDE6JvFNXCoow9ea63dQcaWBOoe7EFglHvRPjmZjbn4m5/Y+3NTc7tlcfZv3Of4TH8m37WLhqx/E+g/olcVpWCmMyUxg9JIXRmSnkD+5H36Tw/XUcvjOLccU+rwDgKSoAKBIp4uL+sWB++bis4+37D9ezfueB4+Gxfmctz360jaONzcf75Kb3ZkxmCvmZXohkpnDK4L4kJYT+9JWCIgyV7zvM38v28p3p+cSpcqZIxOvfJzFQEfeUgcfbmpod/urDlO4+wMZdByjdfYBNuw/ybmnl8Ut24+OMvIF9GO0Fx5ghKYzO7EfewL49esOggiIMLfBuCrpeBQBFolZ8i6OPllUX6hub2Vp1iI0tAmTDrgMsWrvr+OJ5YnwcIwb1ZdTgftx9+andXgNOQRFmmpsd833lnDdqEDkDVABQJNYkJsQdP4Jg/D/a6xqaKNtzMBAguw9Stucga3fU9Mg3BCoowsyHn1ZRsf8I/3r5qaGeioiEkeRe8ce/NbCnqSpWmCn2+Unr3YvLCjJDPRUREUBBEVZqDjfw+tpdXDNxaNTcqCMikU9BEUZePlYAcIoKAIpI+FBQhJGiEj9jh6YydqgKAIpI+FBQhIk1FTWs3VHLDTqaEJEwo6AIE/N9/kABwAnZoZ6KiMhnKCjCQF1DE39euYOZY4eQ1kff7Ssi4UVBEQbeOF4AUKedRCT8KCjCQHGJn5wBvTmnRR0YEZFwoaAIMX/1YT74dC+zJ+eqAKCIhCUFRYgdLwCo750QkTCloAih5mbHguWBAoDZ/XuHejoiIm1SUITQB5/upWL/ES1ii0hYU1CEULGvnP59enHZWBUAFJHwpaAIkf2H61m0dhfXTMwOi686FBFpT1BBYWYzzazUzMrM7O42Xk8ysyLv9aVmltfitXu89lIzm3ECYz5kZgc/53aFvZdX7ggUANRpJxEJc50GhZnFA48AlwMFwI1mVtCq263APufcKOAB4D7vvQXAXGAsMBN41MziOxvTzAqBASe5bWGtqMTP6dmpFAxNDfVUREQ6FMwRxVSgzDm32TlXD8wDZrXqMwt42nu8AJhuZua1z3POHXXObQHKvPHaHdMLkV8CPzq5TQtfaypqWLezlht0NCEiESCYoMgG/C2el3ttbfZxzjUCNcDADt7b0Zh3Agudczs7mpSZ3W5mPjPzVVZWBrEZ4aPYKwB4tQoAikgECKvFbDMbCswGftNZX+fcE865QudcYUZGRvdProvUNTTx5xUVXH66CgCKSGQIJigqgJbnSHK8tjb7mFkCkAZUdfDe9tonAaOAMjPbCvQxs7IgtyUiLFq7i9q6Ri1ii0jECCYoSoB8MxthZokEFqcXtuqzELjFe3w9sNg557z2ud5VUSOAfGBZe2M65/7qnBvinMtzzuUBh70F8qhR7POTm96bs0eqAKCIRIaEzjo45xrN7E5gERAPPOmcW2tm9wI+59xC4A/As96//qsJ/OLH61cMrAMagTucc00AbY3Z9ZsXXvzVh/mgrIq7Lh2tAoAiEjE6DQoA59yrwKut2v6jxeM6AmsLbb33Z8DPghmzjT79gplfpJi/vBwzuG6yCgCKSOQIq8XsaNbU7Fjg8zMtP0MFAEUkoigoesgHZXvZUVPHHJUTF5EIo6DoIcU+P/379OLSAhUAFJHIoqDoAfsO1fPG2t0qACgiEUlB0QNeXllBfZMKAIpIZFJQdDPnHEW+csZlp6kAoIhEJAVFN1tTUcv6nbXMmaKjCRGJTAqKblbs85OUEMfVE4aGeioiIp+LgqIb1TU08eeVXgHA3ioAKCKRSUHRjRat3cWBukaddhKRiKag6EZFJYECgGeNUAFAEYlcCopu4q8+zIefVjFncq4KAIpIRFNQdJP5Pr8KAIpIVFBQdIOmZseC5eWcn5/BUBUAFJEIp6DoBn8/XgBQi9giEvkUFN2g2OdnQJ9eXFIwONRTERE5aQqKLrbvUD1vrt3NNZNUAFBEooOCoou9tCJQAPAG3TshIlFCQdGFnHMU+/yMz0nj1CEqACgi0UFB0YVWV9SwYdcBLWKLSFRRUHShYwUAr1IBQBGJ
IgqKLlLX0MTLK3dwxbgsFQAUkaiioOgir6/xCgDqtJOIRBkFRRcpKvEzLL0PZ45ID/VURES6lIKiC2yvOsySzVXMKcxRAUARiToKii4wf7mfOBUAFJEopaA4SccLAI7OICtNBQBFJPooKE7S+5sq2akCgCISxRQUJ2m+r5z0volcclpmqKciItItFBQnofpQPW+s28U1E7NJTNBfpYhEp6B+u5nZTDMrNbMyM7u7jdeTzKzIe32pmeW1eO0er73UzGZ0NqaZPee1rzGzJ80sbO9ee2lFBQ1NTgUARSSqdRoUZhYPPAJcDhQAN5pZQatutwL7nHOjgAeA+7z3FgBzgbHATOBRM4vvZMzngFOBcUBv4LaT2sJu4pxjvs/PhJw0xgxJCfV0RES6TTBHFFOBMufcZudcPTAPmNWqzyzgae/xAmC6mZnXPs85d9Q5twUo88Zrd0zn3KvOAywDwvKa00/KvQKAOpoQkSgXTFBkA/4Wz8u9tjb7OOcagRpgYAfv7XRM75TTl4HX25qUmd1uZj4z81VWVgaxGV2r2OcnuZcKAIpI9AvnFdhHgfecc++39aJz7gnnXKFzrjAjI6NHJ3akvomFK3dwxelZpCaH7RKKiEiXSAiiTwXQ8vxKjtfWVp9yM0sA0oCqTt7b7phm9p9ABvC/g5hfj3t97U4OHG3UaScRiQnBHFGUAPlmNsLMEgksTi9s1WchcIv3+HpgsbfGsBCY610VNQLIJ7Du0O6YZnYbMAO40TnXfHKb1z2KSvwMH6gCgCISGzo9onDONZrZncAiIB540jm31szuBXzOuYXAH4BnzawMqCbwix+vXzGwDmgE7nDONQG0Nab3kY8B24AlgfVwXnTO3dtlW3yStlUd4qPN1fxwxhi8+YmIRLVgTj3hnHsVeLVV23+0eFwHzG7nvT8DfhbMmF57UHMKlfm+8kABwDPC8mIsEZEuF86L2WHnWAHAC0ZnMCQtOdTTERHpEQqKE/Depkp21aoAoIjEFgXFCZjv85PeN5HpKgAoIjFEQRGkqoNHeXPdbq6dpAKAIhJb9BsvSMcKAOq0k4jEGgVFEJxzFPv8TMjtrwKAIhJzFBRBWFVew8bdB7lBRxMiEoMUFEH4RwHArFBPRUSkxykoOnGkvolXVu7ginFZpKgAoIjEIAVFJ15bEygAqNNOIhKrFBSdKCrxkzewD1NVAFBEYpSCogNb9x5i6ZZqZhfmqgCgiMQsBUUH5i/3qwCgiMQ8BUU7jhUAvHDMYBUAFJGYpqBox3sbK9lde5Q5hTqaEJHYpqBoR1GJn4F9E7n4VBUAFJHYpqBoQ9XBo7y1XgUARURAQdGml1ZU0NjsmDNF906IiCgoWnHOUVTiZ2Juf0ZnqgCgiIiCopWV/v1s2nOQG3Q0ISICKCj+SbGvnN694rlyvAoAioiAguIzDtc38soqFQAUEWlJQdHCa6t3cfBoo047iYi0oKBoocjnZ8SgvkzJGxDqqYiIhA0FhWfL3kMs21LN7MIcFQAUEWlBQeGZ71MBQBGRtigogMamZl74uJyLxgwmM1UFAEVEWlJQAO9tChQAnK1vsRMR+ScKCgIFAAf1S2T6aYNDPRURkbAT80Gx9+BR3l6/h2snZdMrPub/OkRE/knM/2Z86eNAAUDdOyEi0raggsLMZppZqZmVmdndbbyeZGZF3utLzSyvxWv3eO2lZjajszHNbIQ3Rpk3ZuJJbmO7nHMU+/ycMaw/owarAKCISFs6DQoziwceAS4HCoAbzaygVbdbgX3OuVHAA8B93nsLgLnAWGAm8KiZxXcy5n3AA95Y+7yxu8UKrwDgHC1ii4i0K5gjiqlAmXNus3OuHpgHzGrVZxbwtPd4ATDdAnetzQLmOeeOOue2AGXeeG2O6b3nYm8MvDGv+dxb14n5Pn+gAOCEod31ESIiES+YoMgG/C2el3ttbfZxzjUCNcDADt7bXvtAYL83RnufBYCZ3W5mPjPzVVZWBrEZ/2xYel++cm4e/ZISPtf7RURiQcT+hnTOPQE8AVBYWOg+zxjfvPCULp2TiEg0CuaIogJoeRI/x2trs4+ZJQBpQFUH722vvQro743R3meJiEgPCiYoSoB872qkRAKL0wtb9VkI3OI9vh5Y7JxzXvtc76qoEUA+sKy9Mb33vOONgTfmy59/80RE5GR1eurJOddoZncCi4B44Enn3FozuxfwOecWAn8AnjWzMqCawC9+vH7FwDqgEbjDOdcE0NaY3kf+KzDPzH4KrPDGFhGRELHAP+IjW2FhofP5fKGehohIRDGz5c65ws76xfyd2SIi0jEFhYiIdEhBISIiHVJQiIhIh6JiMdvMKoFtn/Ptg4C9XTidSKBtjg3a5uh3sts73DmX0VmnqAiKk2FmvmBW/aOJtjk2aJujX09tr049iYhIhxQUIiLSIQWFV1gwxmibY4O2Ofr1yPbG/BqFiIh0TEcUIiLSIQWFiIh0KKaDwsxmmlmpmZWZ2d2hns+JMLNcM3vHzNaZ2Voz+47Xnm5mb5rZJu+/A7x2M7OHvG39xMzOaDHWLV7/TWZ2S4v2yWa22nvPQ95X1Yac973rK8zsL97zEWa21JtnkVe6Hq+8fZHXvtTM8lqMcY/XXmpmM1q0h93PhJn1N7MFZrbBzNab2dnRvp/N7Hvez/UaM/uTmSVH2342syfNbI+ZrWnR1u37tb3P6JBzLib/EChv/ikwEkgEVgEFoZ7XCcw/CzjDe5wCbAQKgF8Ad3vtdwP3eY+vAF4DDDgLWOq1pwObvf8O8B4P8F5b5vU1772Xh3q7vXndBTwP/MV7XgzM9R4/BnzTe/wt4DHv8VygyHtc4O3vJGCE93MQH64/EwS+O/4273Ei0D+a9zOBrz/eAvRusX+/Em37GTgfOANY06Kt2/dre5/R4VxD/T9BCH8YzwYWtXh+D3BPqOd1EtvzMnApUApkeW1ZQKn3+HHgxhb9S73XbwQeb9H+uNeWBWxo0f6ZfiHczhzgbeBi4C/e/wR7gYTW+5XA952c7T1O8PpZ6319rF84/kwQ+LbILXgXnrTef9G4nwkEhd/75Zfg7ecZ0bifgTw+GxTdvl/b+4yO/sTyqadjP4zHlHttEcc71J4ELAUynXM7vZd2AZne4/a2t6P28jbaQ+1/gB8Bzd7zgcB+51yj97zlPI9vm/d6jdf/RP8uQmkEUAk85Z1u+72Z9SWK97NzrgK4H9gO7CSw35YT3fv5mJ7Yr+19RrtiOSiigpn1A14Avuucq235mgv8kyFqrn82syuBPc655aGeSw9KIHB64rfOuUnAIQKnC46Lwv08AJhFICSHAn2BmSGdVAj0xH4N9jNiOSgqgNwWz3O8tohhZr0IhMRzzrkXvebdZpblvZ4F7PHa29vejtpz2mgPpXOBq81sKzCPwOmnB4H+Znbsa31bzvP4tnmvpwFVnPjfRSiVA+XOuaXe8wUEgiOa9/MlwBbnXKVzrgF4kcC+j+b9fExP7Nf2PqNdsRwUJUC+dyVFIoFFsIUhnlPQvCsY/gCsd879usVLC4FjVz7cQmDt4lj7zd7
VE2cBNd7h5yLgMjMb4P1L7jIC5293ArVmdpb3WTe3GCsknHP3OOdynHN5BPbXYufcl4B3gOu9bq23+djfxfVef+e1z/WulhkB5BNY+Au7nwnn3C7Ab2ZjvKbpBL6DPmr3M4FTTmeZWR9vTse2OWr3cws9sV/b+4z2hXLRKtR/CFxJsJHAFRD/Fur5nODczyNwyPgJsNL7cwWBc7NvA5uAt4B0r78Bj3jbuhoobDHW14Ay789XW7QXAmu89zxMqwXVEG//hfzjqqeRBH4BlAHzgSSvPdl7Xua9PrLF+//N265SWlzlE44/E8BEwOft6z8TuLolqvcz8BNggzevZwlcuRRV+xn4E4E1mAYCR4639sR+be8zOvqjEh4iItKhWD71JCIiQVBQiIhIhxQUIiLSIQWFiIh0SEEhIiIdUlCIiEiHFBQiItKh/w/uhegfvR+Q7QAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "xs = list(range(100000))\n", + "plt.plot(xs, lrs)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "4f4e282c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/wenet/venv/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n" + ] + } + ], + "source": [ + "from typing import Union\n", + "\n", + "from paddle.optimizer.lr import LRScheduler\n", + "from typeguard import check_argument_types\n", + "\n", + "class WarmupLR(LRScheduler):\n", + " \"\"\"The WarmupLR scheduler\n", + " This scheduler is almost same as NoamLR Scheduler except for following\n", + " difference:\n", + " NoamLR:\n", + " lr = optimizer.lr * model_size ** -0.5\n", + " * min(step ** -0.5, step * warmup_step ** -1.5)\n", + " WarmupLR:\n", + " lr = optimizer.lr * warmup_step ** 0.5\n", + " * min(step ** -0.5, step * warmup_step ** -1.5)\n", + " Note that the maximum lr equals to optimizer.lr in this scheduler.\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " warmup_steps: Union[int, float]=25000,\n", + " learning_rate=1.0,\n", + " last_epoch=-1,\n", + " verbose=False):\n", + " assert check_argument_types()\n", + " self.warmup_steps = warmup_steps\n", + " super().__init__(learning_rate, last_epoch, verbose)\n", + "\n", + " def __repr__(self):\n", + " return f\"{self.__class__.__name__}(warmup_steps={self.warmup_steps})\"\n", + "\n", + " def get_lr(self):\n", + " step_num = self.last_epoch + 1\n", + " return self.base_lr * self.warmup_steps**0.5 * min(\n", + " step_num**-0.5, step_num * self.warmup_steps**-1.5)\n", + "\n", + " def set_step(self, step: int):\n", + " self.step(step)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "8c40b202", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-1\n" + ] + } + ], + "source": [ + "sc = WarmupLR(warmup_steps=25000, learning_rate=0.001)\n", + "print(step)\n", + "#sc.set_step(step)\n", + "sc.set_step(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "ecbc7e37", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqaUlEQVR4nO3de3xU9Z3/8dcnCUm4JIGEEAIBEiCAQW4SEG94F7QqagGhu9Varb9a3W51267+tr/dtrvdVevW1VardrVaa4WAN7QqKqJ4QchwvwYiAZMQICQQ7uT2/f0xB4xpLoMkmcnM+/l48GDmO99z5ns4Yd4553vOZ8w5h4iISHOigj0AEREJbQoKERFpkYJCRERapKAQEZEWKShERKRFMcEeQFvo3bu3y8zMDPYwREQ6lRUrVux1zqW21i8sgiIzMxOfzxfsYYiIdCpmtiOQfjr1JCIiLVJQiIhIixQUIiLSIgWFiIi0SEEhIiItCigozGyqmRWYWaGZ3dvE63FmNtd7fZmZZTZ47T6vvcDMpjRof8bM9pjZ+kbrSjazd81sq/d3r9PYPhEROU2tBoWZRQOPAVcCOcBsM8tp1O1WYJ9zbijwMPCAt2wOMAsYCUwFHvfWB/Cs19bYvcAi51w2sMh7LiIiQRLIEcVEoNA5t805Vw3MAaY16jMNeM57PB+41MzMa5/jnDvunCsCCr314ZxbAlQ28X4N1/UccF3gmyPtaVv5IT4o2BPsYYhIBwskKPoDxQ2el3htTfZxztUCVUBKgMs2luacK/Me7wLSmupkZrebmc/MfOXl5QFshpyuWU99xnf+mM+iTbuDPRQR6UAhPZnt/N+q1OQ3KznnnnLO5TrnclNTW70DXU7T1t0H2XPwOAA/mrOaz8sPBXlEItJRAgmKUmBAg+cZXluTfcwsBkgCKgJctrHdZpburSsd0LmOEJDnKyYmynj9rvPpEhPF7X/ycfBYTbCHJSIdIJCgyAeyzSzLzGLxT04vaNRnAXCz93g68L53NLAAmOVdFZUFZAPLW3m/huu6GXgtgDFKO6qpq+fllaVcdkYaozKSeOxbZ7G94gj35K2hvl5fpSsS7loNCm/O4S5gIbAJyHPObTCzX5rZtV63p4EUMysE7sG7Usk5twHIAzYCbwN3OufqAMzsRWApMNzMSszsVm9d9wOXm9lW4DLvuQTRok17qDhczcwJGQCcMySFf7nqDN7duJtH398a5NGJSHsz/y/+nVtubq5T9dj2c+uz+azfWcUn/3wJMdH+3y2cc/x43lpeWlnCI7PGMm1sa9coiEioMbMVzrnc1vqF9GS2BN/uA8dYXLCHb56VcTIkAMyM/7zhTCZmJfOTeWvJ397Ulc4iEg4UFNKil1aWUO9gZu6Av3ktLiaap749noxeXbn9Tz627z0chBGKSHtTUEiznHPM85UwMSuZzN7dm+zTs1ssf7xlAmbGLc/ms+9wdQePUkTam4JCmrW8qJKivYe5sYmjiYYGpXTnDzeNp3T/Ub73Jx9Hq+s6aIQi0hEUFNKsPF8JPeJiuHJU31b7jh+UzP/cOJYVX+zjBy+soKauvgNGKCIdQUEhTTp4rIY315VxzZh+dIsN7KvVrxqVzn9eP4rFBeX8eJ7usRAJF4F9AkjEeWNtGUdr6rhxQsunnRqbPXEg+45U8+DbBSR17cIvrh2Jvz6kiHRWCgpp0tz8Yoal9WBMRtIpL3vHhUPYf6SGp5Zso2fXLtxzxfB2GKGIdBQFhfyNLbsPsrp4Pz/7xhlf62jAzLjvyhFUHanh0fcLiY2J4q5LstthpCLSERQU8jfy8ovpEm1cP+7r323tvyFvFDV19Tz0zhbMjDsvHtqGoxSRjqKgkK+orq3nlVX+AoApPeJOa13RUcavZ4zBAb9eWACgsBDphBQU8hXvb97tLwDYyr0TgYqOMh6aMQbwh4UZ/OAihYVIZ6KgkK/I85XQNzGeycPa7sugToSFc44H3y6gptbxw0uH6mookU5CQSEn7ao6xgcFe7jjoiFER7Xth3h0lPHfM8cSEx3Fw+9toepoDT/7xhlEtfH7iEjbU1DISScKAM4Y3zannRqLjjIe/OZoEuJjeOaTIg4cq+H+G0Z9pSqtiIQeBYUAJwoAFnN2CwUA20JUlPGvV+eQ1LUL//PeVg4dq+WR2WOJi4lut/cUkdOjX+UEgGVFlWyvOHLKd2J/HWbGjy4bxr9encPbG3bx3WfzOaDv3xYJWQoKASDPV0xCXAxXnpneYe/53fOz+O8ZY1i2rZIZv1/Kzv1HO+y9RSRwCgrhwIkCgGP70TW2Y08BfXN8Bs/eMpGd+49y3WOfsL60qkPfX0Rap6AQ3lhTxrGa+ja7d+JUnZ/dm3l3nENMlHHjk0tZXLAnKOMQkaYpKIS5vmKGpyV8rQKAbWVE30ReufM8BqV057bnfDy/dDvOqUy5SChQUES4gl0HWVO8n5kTBgT9Bri0xHjyvn8OFw5L5f+9toH/+8o6qmv1BUgiwaagiHB5vtMvANiWesTF8IebcvnBRUN4cXkxs//wGXsOHgv2sEQimoIigp0oAHh5ThrJ3WODPZyToqOMn04dwe++NY6NOw9w7W8/YW3J/mAPSyRiKSgi2KJNu6k8XM2MIE1it+bq0f2Yf8c5REcZ059YyjxfcbCHJBKRFBQRLM9X7C8AmN12BQDb2sh+SSy46zxyB/XiJ/PX8uN5azhaXRfsYYlEFAVFhNpVdYwPt5QzfXxGmxcAbGspPeJ4/taz+eElQ3lpZQnTHvuYwj0Hgz0skYihoIhQJwsA5mYEeygBiY4y7rliOM/dMpGKQ9Vc+7tPeGVVSbCHJRIRFBQRqL7ekecrZtLgZAaltF8BwPYweVgqf/3hBZzZL4m7567hp/PXcPh4bbCHJRLWFBQRaPn2SnZ0UAHA9tA3KZ6/fO9sfnDREOatKOGqRz9i5Rf7gj0skbCloIhAefn+AoBTR3ZcAcC2FhMdxU+njmDO9yZRW+eY8cRSHn53C7V1ukFPpK0FFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTWlunmV1qZivNbLWZfWxm+oLlNnTgWA1vri/j2iAUAGwPZw9O4a0fXcC0Mf14ZNFWpj+xlKK9h4M9LJGw0mpQmFk08BhwJZADzDaznEbdbgX2OeeGAg8DD3jL5gCzgJHAVOBxM4tuZZ2/B/7OOTcW+Avws9PaQvmK19fsDGoBwPaQGN+F39w4lt/OHse28kNc9chHPPtJEfX1qhUl0hYCOaKYCBQ657Y556qBOcC0Rn2mAc95j+cDl5q/cNA0YI5z7rhzrggo9NbX0jodkOg9TgJ2fr1Nk6bk5Rczom8Co4NYALC9XDOmHwvvnszErGR+/vpGZj65lG3lh4I9LJFOL5Cg6A80vCW2xGtrso9zrhaoAlJaWLaldd4GvGlmJcC3gfubGpSZ3W5mPjPzlZeXB7AZsnnXAdaUVDEzN/gFANtLelJXnr1lAg/NGMOW3QeZ+shHPPHh55q7EDkNoTiZfTdwlXMuA/gj8JumOjnnnn
LO5TrnclNTQ/fO4lCSl19Cl2jjuhApANhezIzp4zN4754LuXh4Kve/tZkbfv8pm3cdCPbQRDqlQIKiFGh4QjvDa2uyj5nF4D9lVNHCsk22m1kqMMY5t8xrnwucG9CWSIv8BQBLuCKnb0gVAGxPfRLjeeLvx/O7b42jdN9RvvHox/znm5t034XIKQokKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzv/N86swCY5V0VlQVkA8tbWOc+IMnMhnnruhzY9PU3T054b9Nu9h2p6TR3YrcVM+Pq0f14754LmTE+g6eWbOOy33zIW+vK9MVIIgGKaa2Dc67WzO4CFgLRwDPOuQ1m9kvA55xbADwNPG9mhUAl/g9+vH55wEagFrjTOVcH0NQ6vfbvAS+ZWT3+4Phum25xhMrzFZOeFM8FIVwAsD316h7L/d8czYzcAfzs1fXc8cJKLhyWyi+uHUlm7851d7pIR7Nw+K0qNzfX+Xy+YA8jZJVVHeW8+9/nzouH8k9XDA/2cIKutq6ePy3dwW/e3UJ1XT3fv3AI379wMN1iW/29SSSsmNkK51xua/1CcTJb2thLK7wCgOPD596J0xETHcV3z89i0T9dyJSRfXl00VYueehDXl5ZonsvRJqgoAhz/gKAJZwzOIWBKd2CPZyQkpYYz29nj2Pe98+hT2Ic9+St4brHP8G3vTLYQxMJKQqKMLesqJIvKjtvAcCOMCEzmVd/cB6/mTmG3QeOMf2Jpdz5l5UUVx4J9tBEQoJOyoa5PF8xCfExTD2zb7CHEtKioowbzspg6pl9efLDbTy55HPe3bCbv580iDsvHkJKj7hgD1EkaHREEcaqjtbw5roypo3tR3yXzl8AsCN0i43h7suHsfjHF3H9uP48+2kRkx9czMPvbuHgsZpgD08kKBQUYez1NTs5XhteBQA7SnpSVx6YPpp37p7M5GGpPLJoKxf++gOe/riIYzX6zm6JLAqKMJbn8xcAHNU//AoAdpShfRL4/d+P57U7zyMnPZF/f2Mjlzz0AS8u/4LqWtWPksigoAhTm8oOsDbMCwB2pDEDevLn287mhdvOJjUxnvteXsfFD33Anz/bwfFaHWFIeFNQhKk8XzGx0VFcH+YFADvaeUN78+oPzuXZWybQJzGOn726ngsf/IDnPt2uU1ISthQUYeh4bR2vrirl8pFp9IqQAoAdycy4aHgfXr7jXP5869kMSO7Kvy3YwOQHF/P0x0UcrVZgSHjR5bFh6L2Ne9h3pEaT2O3MzDg/uzfnDU1h6bYKHl20lX9/YyO/e38rN52TyU3nDNJltRIWFBRhKM9XTL+keM4f2jvYQ4kIZsa5Q3pz7pDe5G+v5MkPP+eRRVt5csnnzBg/gNsuyGJQigoPSueloAgzO/cfZcnWcv7h4qFER2kSu6NNyExmQmYyhXsO8tSSbczNL+aFZTu48sx0bp88mDEDegZ7iCKnTEERZl5aUYJzMEOnnYJqaJ8EHpw+hh9fMZw/frqdP3+2g7+uK2NiVjLfOTeTK3LSiInWFKF0DiozHkbq6x0XPrSYAb268ZfvTQr2cKSBQ8drmbP8C579dDsl+47SLymev5s0iNkTB0bMNw5K6FGZ8Qj0WVEFxZVHVQAwBPWIi+G2Cwbz4U8u5g835ZKV2p1fLyxg0n8t4sfz1rC+tCrYQxRplk49hZG8fH8BwCkjVQAwVEVHGZfnpHF5Thpbdx/kuaXbeXllKfNXlDB+UC9uOmcQU0b2VW0uCSkKijBRdbSGt9bvYmbuAH3IdBLZaQn8x3Wj+MmUEcxfUcLzS7fzj3NW07NbF24Yl8HsiQPITksI9jBFFBThYoEKAHZaSV27cOv5WdxybiZLt1Xw4vIveP6z7TzzSRG5g3oxe+JArhqVTtdY/QIgwaHJ7DBxzW8/prbe8eYPz1dtpzBQceg4L68s5cXlX7Bt72ES4mO4YVx/Zk4YwMh+KvIobSPQyWwdUYSBjTsPsK60in+7JkchESZSesTxvcmDue2CLJYXVfLi8i94Mb+Y55buYETfBL55VgbTxvajT2J8sIcqEUBBEQZOFAC8bqwKAIYbM+PswSmcPTiFnx+p5vW1Zby0ooRfvbmJ/3prE5OHpXLDWRlckZOmuSlpNwqKTu54bR2vrlYBwEjQs1ss3540iG9PGsTn5Yd4eWUJr6ws5YcvriIhLoZvjE7nhrMyyB3UiyjdlS9tSEHRyb27cTf7j9RwoyaxI8qQ1B78ZMoI/uny4XxWVMFLK0pZsGYnc/L9db6uHtOPq0enM6p/kk5HymnTZHYnd9Mzy/l8zyGW/PRi1XaKcEeqa3lnw25eX7OTJVvLqalzDErpxjWj+3H1mHSGpyUoNOQrNJkdAUr3H+WjreX8wyXZCgmhW2wM143rz3Xj+lN1pIaFG3bx+tqdPP5BIb9bXEh2nx5cPbof14xJZ3Bqj2APVzoRBUUndrIA4PiMYA9FQkxSty7MnDCAmRMGsPfQcd5aV8bra8t4+L0tPPzeFkb0TWDKyL5MGdmXM9J1pCEt06mnTqq+3jH514sZlNKNF25TAUAJTFnVUd5ct4uF63eRv6MS52BgcjemjExjysi+nDVQE+GRRKeewtxn2yoo2XeUn0wZHuyhSCeSntSVW8/P4tbzsyg/eJz3Nu1m4YZdPPvpdv7wURGpCXFcnpPG1JF9mTQ4hdgY1Q0VBUWnNddXTKIKAMppSE2IY/bEgcyeOJADx2pYvHkPCzfs4tVVpfxl2RckxMcweVgql47ow0XD+6gcegQLKCjMbCrwCBAN/K9z7v5Gr8cBfwLGAxXAjc657d5r9wG3AnXAD51zC1tap/lPlv4HMMNb5vfOuUdPbzPDS9URfwHAWRNUAFDaRmJ8F6aN7c+0sf05VlPHx1v38s7GXSwuKOeva8swg3EDenLJiD5cMiJN8xoRptWgMLNo4DHgcqAEyDezBc65jQ263Qrsc84NNbNZwAPAjWaWA8wCRgL9gPfMbJi3THPr/A4wABjhnKs3sz5tsaHhZMGaUqpVAFDaSXyXaC7LSeOynDTq6x3rd1bx/uY9vL95Dw+9s4WH3tlCelI8F4/owyXD+3De0N4qWBjmAjmimAgUOue2AZjZHGAa0DAopgE/9x7PB37nHRlMA+Y4544DRWZW6K2PFtZ5B/At51w9gHNuz9ffvPA011dMTnoiZ/ZXcThpX1FRxuiMnozO6MmPLhvGngPH+KCgnEWbd/Oad4oqLiaKiVnJTM5O5YJhvXW/RhgKJCj6A8UNnpcAZzfXxzlXa2ZVQIrX/lmjZU8UJGpunUPwH41cD5TjP121tfGgzOx24HaAgQMHBrAZ4WHDzirWlx7g59fkBHsoEoH6JMafvOz2eG0dy4sqWby5nI+2lvOrNzfBm/65jwuyezM5O5XzhvYmNSEu2MOW0xSKk9lxwDHnXK6Z3QA8A1zQuJNz7ingKfBfHtuxQwyeeb4SfwHAcSoAKMEVFxPNBdmpXJCdCvgvvf1o614+2rqXxZv38PLKUgBy0hO5YJg/OHIzexEXo9NUnU0gQVGKf87ghAyvr
ak+JWYWAyThn9Ruadnm2kuAl73HrwB/DGCMEeFYTR2vrCrlipFp9OymK1AktKQndWVm7gBm5g6gvt6xYecBlmwtZ8mWcp75uIgnP9xGfJcocgclc86QFCYNTmF0RhJdonUJbqgLJCjygWwzy8L/YT4L+FajPguAm4GlwHTgfeecM7MFwF/M7Df4J7OzgeWAtbDOV4GLgSLgQmDL1966MPPuxt1UHa3hxgmaxJbQFhVljMpIYlRGEndePJTDx2tZVlTBki17+WxbBb9eWABAt9hoJmQmM2lwCucMSeHMfonEKDhCTqtB4c053AUsxH8p6zPOuQ1m9kvA55xbADwNPO9NVlfi/+DH65eHf5K6FrjTOVcH0NQ6vbe8H3jBzO4GDgG3td3mdm55vmL69+zKeUN6B3soIqeke1wMl4xI45IRaYD/G/yWFVXy2bYKln5ewQNvbwYgIS6GCVnJnOMFxxnpiapjFgJUwqOTKNl3hAseXMwPL8nm7suHtb6ASCdSfvC4PzS2VfDZ5xVs23sYgIT4GMYP6sWEzGRyB/VizICeuneoDamER5h5aYV/CmdGrgoASvhJTYjjmjH9uGZMPwB2VR3js20VLN9eSX5RJR8U+E9VdYk2RvVP8geHFx76wq72pyOKTuBEAcDMlO78+bbGVyaLhL99h6tZsWMf+Tsq8W3fx9qS/dTU+T+7hvbpwYTMXuQOSmZCZjIDkrvqPo4A6YgijCz1CgD+dOqIYA9FJCh6dY89ebc4+K8AXFtSRf72SnzbK3ljbRkvLvffmpXSPZaxA3oybmBPxg3sxeiMJBLiuwRz+J2egqITmJtfTFLXLlzh/ScRiXTxXaKZmJXMxKxkwH/UvWXPQfK372P1F/tZXbyPRZv9RR3MILtPDy88ejF2QE+GpSVokvwUKChCXNWRGt7esIvZKgAo0qyoKGNE30RG9E3k25MGAf7/O2tK9rPKC453Nu4mz1cCQPfYaEZlJJ0MjtEZSfRNjNcpq2YoKELca14BwBkqAChySpK6dWHysFQmD/PfOe6cY0fFEVYV+486VhXv5w9LtlFb75/r6N0jjlH9ExnVP4lRGT0Z1T+JtMQ4hQcKipCX5ytmZD8VABQ5XWZGZu/uZPbuzvXj/FcPHqupY8POA6wvrWJtSRXrS6v4cEs5XnbQu0ccozOSOLN/EqP7+28gTEuMD+JWBIeCIoSdKAD4i2tHBnsoImEpvks04wf1YvygXifbjlTXsqnsAOtKqlhb6g+PDwr2nAyP1IQ4Rvf3h0dOv0Ry0hPJ6BXeV1opKEJYXn4xsTFRTBvbL9hDEYkY3WJjGD8omfGDkk+2HamuZePOA6wrrWJdSRXrSqtY3CA8EuJjOKNvIjn9EjkjPYEz0hMZlpYQNvOKCooQdaymjldX72TKyL4qACgSZN1iY/w3+GV+NTwKdh1kU9lBNpZVsansIHm+Yo5U1wEQHWUM7t3dCw//n5z0xE5Zdl1BEaLeOVEAUJPYIiGpW2wM4wb2YtzAL09b1dc7vqg8wsayA2wqO8DGnQfIL6rktdU7T/bp3SOOM9ITGJ6WwLC+/r+z03rQLTZ0P45Dd2QRbp5XAPDcISnBHoqIBCgq6ssJ86tGpZ9s33e4mk27/MGxqewgm8oO8KeiHVTX1p/sMyC5qz880hIY3tf/9+DU7iHx/R0KihBUsu8IHxfu5R8vzSZKNwWJdHq9usdy7pDenNug8nNdvWNHxWG27D7Ilt2HKNh9kC27DvJBQfnJS3ajo4zMlG4ng+PEn8yUbh1ajl1BEYLmr/DfFDR9vAoAioSr6ChjcGoPBqf2YOqZX7ZX19ZTtPfwyeDYsvsgG3ce4K31uzhRmi82Ooqs3t0ZmtaDe6eOYEByt3Ydq4IixNTXO+b5Sjh/aG8yerXvzheR0BMbE8Xwvv7TT4z5sv1odR2Few55RyAHKdxziHUlVcTGtP+RhYIixHz6eQWl+49y75UqACgiX+rqlR0ZldHxN9/qOwdDzFyfvwDg5SoAKCIhQkERQvYfqWbhhl1cP65/2NyoIyKdn4IihLy2eqdXAFCT2CISOhQUISTPV8yZ/RMZ2U8FAEUkdCgoQsT60io27DzATN2JLSIhRkERIvJ8XgHAMf2DPRQRka9QUISAYzV1vLqqlKkj+5LUTd/tKyKhRUERAhZu2MWBY7XcOEGnnUQk9CgoQsA8XwkZvbpyzmAVABSR0KOgCLLiSn8BwBnjB6gAoIiEJAVFkM1fUYIZTNe9EyISohQUQVRX75i/wl8AsH/PrsEejohIkxQUQfTp53sp3X9Uk9giEtIUFEE0N7+Ynt1UAFBEQpuCIkj2H6nmnQ27uW5s/5D4qkMRkeYEFBRmNtXMCsys0MzubeL1ODOb672+zMwyG7x2n9deYGZTTmGdj5rZoa+5XSHv1VWlVNfVq2SHiIS8VoPCzKKBx4ArgRxgtpnlNOp2K7DPOTcUeBh4wFs2B5gFjASmAo+bWXRr6zSzXKDXaW5bSMvzlTCqfxI5/RKDPRQRkRYFckQxESh0zm1zzlUDc4BpjfpMA57zHs8HLjUz89rnOOeOO+eKgEJvfc2u0wuRXwM/Pb1NC13rS6vYWHaAmbokVkQ6gUCCoj9Q3OB5idfWZB/nXC1QBaS0sGxL67wLWOCcK2tpUGZ2u5n5zMxXXl4ewGaEjjxfMXExUVw7VgUARST0hdRktpn1A2YAv22tr3PuKedcrnMuNzU1tf0H10ZOFgA8sy9JXVUAUERCXyBBUQo0nHHN8Nqa7GNmMUASUNHCss21jwOGAoVmth3oZmaFAW5Lp3CyAKAmsUWkkwgkKPKBbDPLMrNY/JPTCxr1WQDc7D2eDrzvnHNe+yzvqqgsIBtY3tw6nXN/dc71dc5lOucygSPeBHnYyPMVMyC5K5NUAFBEOomY1jo452rN7C5gIRANPOOc22BmvwR8zrkFwNPA895v/5X4P/jx+uUBG4Fa4E7nXB1AU+ts+80LLcWVR/iksIJ7Lh+mAoAi0mm0GhQAzrk3gTcbtf1rg8fH8M8tNLXsr4BfBbLOJvr0CGR8ncU8rwDgN8fraicR6TxCajI7nNXVO+b7irkgO1UFAEWkU1FQdJBPCveys+qYJrFFpNNRUHSQub5ienXrwmU5fYI9FBGRU6Kg6AD7Dlfz7obdXDdOBQBFpPNRUHSAV1erAKCIdF4KinbmnGNufjGjM5I4I10FAEWk81FQtLP1pQfYvOsgM3Q0ISKdlIKinZ0sADimX7CHIiLytSgo2tGxmjpeXV3KlSoAKCKdmIKiHb29fhcHj9Uyc4JOO4lI56WgaEcnCwBmqQCgiHReCop28kXFET79vIKZ4weoAKCIdGoKinYyf0WxCgCKSFhQULSDunrHvBUlTM5OpZ8KAIpIJ6egaAcfF+6lrOoYN2oSW0TCgIKiHeTl+wsAXnqGCgCKSOenoGhjlYereWfjLq4fl6ECgCISFhQUbezVVaXU1DlmTtAktoiEBwVFG3LOkecrZkxGEiP6qgCgiIQHBUUbWldapQKAIhJ2FBRt6GQBwLEqACgi
4UNB0UaO1dTx2uqdXDUqncR4FQAUkfChoGgjJwsA6rSTiIQZBUUbmZtfzMDkbpydlRzsoYiItCkFRRvYUXGYpdsqmJmboQKAIhJ2FBRtYP6KEqJUAFBEwpSC4jTV1Tvmryhh8rBU0pNUAFBEwo+C4jR9tLWcsqpjmsQWkbCloDhNeb5ikrvHctkZacEeiohIu1BQnIbKw9W8u3E314/rT2yM/ilFJDwF9OlmZlPNrMDMCs3s3iZejzOzud7ry8wss8Fr93ntBWY2pbV1mtkLXvt6M3vGzEL27rVXThQA1GknEQljrQaFmUUDjwFXAjnAbDPLadTtVmCfc24o8DDwgLdsDjALGAlMBR43s+hW1vkCMAIYBXQFbjutLWwnzjnm+YoZM6Anw/smBHs4IiLtJpAjiolAoXNum3OuGpgDTGvUZxrwnPd4PnCpmZnXPsc5d9w5VwQUeutrdp3OuTedB1gOhOQ1p2tL/AUAZ+aG5PBERNpMIEHRHyhu8LzEa2uyj3OuFqgCUlpYttV1eqecvg283dSgzOx2M/OZma+8vDyAzWhbeb5i4rtEcc0YFQAUkfAWyjOwjwNLnHMfNfWic+4p51yucy43NTW1Qwd2tLqOBat3ctWZKgAoIuEvJoA+pUDD2doMr62pPiVmFgMkARWtLNvsOs3s34BU4P8EML4O9/aGMg4er2XmBE1ii0j4C+SIIh/INrMsM4vFPzm9oFGfBcDN3uPpwPveHMMCYJZ3VVQWkI1/3qHZdZrZbcAUYLZzrv70Nq99zM0vZlCKCgCKSGRo9YjCOVdrZncBC4Fo4Bnn3AYz+yXgc84tAJ4GnjezQqAS/wc/Xr88YCNQC9zpnKsDaGqd3ls+AewAlvrnw3nZOffLNtvi07Sj4jCfbavkJ1OG441PRCSsBXLqCefcm8Cbjdr+tcHjY8CMZpb9FfCrQNbptQc0pmCZ5/MKAJ6lq51EJDKE8mR2yDlRAPDCYan0TYoP9nBERDqEguIULNlazq4DKgAoIpFFQXEK8vKLSekey6UqACgiEURBEaCKQ8d5b5MKAIpI5NEnXoBOFgDUvRMiEmEUFAFwzpHnK2bsgJ4MS1MBQBGJLAqKAKwpqWLL7kOaxBaRiKSgCMCXBQDTgz0UEZEOp6BoxdHqOl5fvZOrRqWToAKAIhKBFBSteGu9vwDgjTrtJCIRSkHRirn5xWSmdGOiCgCKSIRSULRg+97DLCuqZEbuABUAFJGIpaBowbwVxSoAKCIRT0HRjNq6euavKOGi4X1UAFBEIpqCohkfbd3L7gPHmZmrowkRiWwKimbM9QoAXjJCBQBFJLIpKJqgAoAiIl/Sp2ATXllVSm2940YVABQRUVA05pxjbn4x4wb2JFsFAEVEFBSNrS7ez9Y9KgAoInKCgqKRPF8JXbtEc/VoFQAUEQEFxVccqa7l9TUqACgi0pCCooG31u3i0PFaTWKLiDSgoGhgrq+YrN7dmZDZK9hDEREJGQoKT9HewywvqmRGboYKAIqINKCg8MzzqQCgiEhTFBT4CwC+tLKEi4f3IS1RBQBFRBpSUABLtpaz+8BxZujeCRGRv6GgwF8AsHePWC49o0+whyIiEnIiPij2HjrOok17uH5cf7pER/w/h4jI34j4T8ZXVqoAoIhISwIKCjObamYFZlZoZvc28Xqcmc31Xl9mZpkNXrvPay8wsymtrdPMsrx1FHrrjD3NbWyWc448XzFnDezJ0D4qACgi0pRWg8LMooHHgCuBHGC2meU06nYrsM85NxR4GHjAWzYHmAWMBKYCj5tZdCvrfAB42FvXPm/d7WKVCgCKiLQqkCOKiUChc26bc64amANMa9RnGvCc93g+cKn571qbBsxxzh13zhUBhd76mlynt8wl3jrw1nnd1966VszzFfsLAI7p115vISLS6QUSFP2B4gbPS7y2Jvs452qBKiClhWWba08B9nvraO69ADCz283MZ2a+8vLyADbjbw1M7s53zsukR1zM11peRCQSdNpPSOfcU8BTALm5ue7rrOOOi4a06ZhERMJRIEcUpUDDk/gZXluTfcwsBkgCKlpYtrn2CqCnt47m3ktERDpQIEGRD2R7VyPF4p+cXtCozwLgZu/xdOB955zz2md5V0VlAdnA8ubW6S2z2FsH3jpf+/qbJyIip6vVU0/OuVozuwtYCEQDzzjnNpjZLwGfc24B8DTwvJkVApX4P/jx+uUBG4Fa4E7nXB1AU+v03vKfgTlm9h/AKm/dIiISJOb/Jb5zy83NdT6fL9jDEBHpVMxshXMut7V+EX9ntoiItExBISIiLVJQiIhIixQUIiLSorCYzDazcmDH11y8N7C3DYfTGWibI4O2Ofyd7vYOcs6lttYpLILidJiZL5BZ/3CibY4M2ubw11Hbq1NPIiLSIgWFiIi0SEHhFRaMMNrmyKBtDn8dsr0RP0chIiIt0xGFiIi0SEEhIiItiuigMLOpZlZgZoVmdm+wx3MqzGyAmS02s41mtsHM/tFrTzazd81sq/d3L6/dzOxRb1vXmtlZDdZ1s9d/q5nd3KB9vJmt85Z51Puq2qDzvnd9lZm94T3PMrNl3jjneqXr8crbz/Xal5lZZoN13Oe1F5jZlAbtIfczYWY9zWy+mW02s01mdk6472czu9v7uV5vZi+aWXy47Wcze8bM9pjZ+gZt7b5fm3uPFjnnIvIP/vLmnwODgVhgDZAT7HGdwvjTgbO8xwnAFiAHeBC412u/F3jAe3wV8BZgwCRgmdeeDGzz/u7lPe7lvbbc62veslcGe7u9cd0D/AV4w3ueB8zyHj8B3OE9/gHwhPd4FjDXe5zj7e84IMv7OYgO1Z8J/N8df5v3OBboGc77Gf/XHxcBXRvs3++E234GJgNnAesbtLX7fm3uPVoca7D/EwTxh/EcYGGD5/cB9wV7XKexPa8BlwMFQLrXlg4UeI+fBGY36F/gvT4beLJB+5NeWzqwuUH7V/oFcTszgEXAJcAb3n+CvUBM4/2K//tOzvEex3j9rPG+PtEvFH8m8H9bZBHehSeN91847mf8QVHsffjFePt5SjjuZyCTrwZFu+/X5t6jpT+RfOrpxA/jCSVeW6fjHWqPA5YBac65Mu+lXUCa97i57W2pvaSJ9mD7H+CnQL33PAXY75yr9Z43HOfJbfNer/L6n+q/RTBlAeXAH73Tbf9rZt0J4/3snCsFHgK+AMrw77cVhPd+PqEj9mtz79GsSA6KsGBmPYCXgB855w40fM35f2UIm+ufzexqYI9zbkWwx9KBYvCfnvi9c24ccBj/6YKTwnA/9wKm4Q/JfkB3YGpQBxUEHbFfA32PSA6KUmBAg+cZXlunYWZd8IfEC865l73m3WaW7r2eDuzx2pvb3pbaM5poD6bzgGvNbDswB//pp0eAnmZ24mt9G47z5LZ5rycBFZz6v0UwlQAlzrll3vP5+IMjnPfzZUCRc67cOVcDvIx/34fzfj6hI/Zrc+/RrEgOinwg27uSIhb/JNiCII8pYN4VDE8Dm5xzv2nw0gLgxJUPN+OfuzjRfpN39cQkoMo7/Fw
IXGFmvbzf5K7Af/62DDhgZpO897qpwbqCwjl3n3MuwzmXiX9/ve+c+ztgMTDd69Z4m0/8W0z3+juvfZZ3tUwWkI1/4i/kfiacc7uAYjMb7jVdiv876MN2P+M/5TTJzLp5YzqxzWG7nxvoiP3a3Hs0L5iTVsH+g/9Kgi34r4D4l2CP5xTHfj7+Q8a1wGrvz1X4z80uArYC7wHJXn8DHvO2dR2Q22Bd3wUKvT+3NGjPBdZ7y/yORhOqQd7+i/jyqqfB+D8ACoF5QJzXHu89L/ReH9xg+X/xtquABlf5hOLPBDAW8Hn7+lX8V7eE9X4GfgFs9sb1PP4rl8JqPwMv4p+DqcF/5HhrR+zX5t6jpT8q4SEiIi2K5FNPIiISAAWFiIi0SEEhIiItUlCIiEiLFBQiItIiBYWIiLRIQSEiIi36/zob5nVzA95IAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "lrs=[]\n", + "for i in range(100000):\n", + " sc.step()\n", + " lrs.append(sc.get_lr())\n", + "xs = list(range(100000))\n", + "plt.plot(xs, lrs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e613fe16", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0fd9f40", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/.notebook/u2_model.ipynb b/.notebook/u2_confermer_model_wenet.ipynb similarity index 100% rename from .notebook/u2_model.ipynb rename to .notebook/u2_confermer_model_wenet.ipynb diff --git a/.notebook/u2_tansformer_model_espnet.ipynb b/.notebook/u2_tansformer_model_espnet.ipynb new file mode 100644 index 00000000..75c2ea5c --- /dev/null +++ b/.notebook/u2_tansformer_model_espnet.ipynb @@ -0,0 +1,1672 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "choice-grade", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x\n" + ] + }, + { + "data": { + "text/plain": [ + "'/workspace/DeepSpeech-2.x'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%cd ..\n", + "%pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "broke-broad", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. 
If you wish to review your current use, check the release note link for additional information.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " def convert_to_list(value, n, name, dtype=np.int):\n", + "register user softmax to paddle, remove this when fixed!\n", + "register user log_softmax to paddle, remove this when fixed!\n", + "register user sigmoid to paddle, remove this when fixed!\n", + "register user log_sigmoid to paddle, remove this when fixed!\n", + "register user relu to paddle, remove this when fixed!\n", + "override cat of paddle if exists or register, remove this when fixed!\n", + "override item of paddle.Tensor if exists or register, remove this when fixed!\n", + "override long of paddle.Tensor if exists or register, remove this when fixed!\n", + "override new_full of paddle.Tensor if exists or register, remove this when fixed!\n", + "override eq of paddle.Tensor if exists or register, remove this when fixed!\n", + "override eq of paddle if exists or register, remove this when fixed!\n", + "override contiguous of paddle.Tensor if exists or register, remove this when fixed!\n", + "override size of paddle.Tensor (`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!\n", + "register user view to paddle.Tensor, remove this when fixed!\n", + "register user view_as to paddle.Tensor, remove this when fixed!\n", + "register user masked_fill to paddle.Tensor, remove this when fixed!\n", + "register user masked_fill_ to paddle.Tensor, remove this when fixed!\n", + "register user fill_ to paddle.Tensor, remove this when fixed!\n", + "register user repeat to paddle.Tensor, remove this when fixed!\n", + "register user softmax to paddle.Tensor, remove this when fixed!\n", + "register user sigmoid to paddle.Tensor, remove this when fixed!\n", + "register user relu to paddle.Tensor, remove this when fixed!\n", + "register user type_as to paddle.Tensor, remove this when fixed!\n", + "register user to to paddle.Tensor, remove this when fixed!\n", + "register user float to paddle.Tensor, remove this when fixed!\n", + "register user tolist to paddle.Tensor, remove this when fixed!\n", + "register user glu to paddle.nn.functional, remove this when fixed!\n", + "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!\n", + "register user Module to paddle.nn, remove this when fixed!\n", + "register user ModuleList to paddle.nn, remove this when fixed!\n", + "register user GLU to paddle.nn, remove this when fixed!\n", + "register user ConstantPad2d to paddle.nn, remove this when fixed!\n", + "register user export to paddle.jit, remove this when fixed!\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import paddle\n", + "from yacs.config import CfgNode as CN\n", + "\n", + "from deepspeech.models.u2 import U2Model\n", + "from deepspeech.utils.layer_tools import print_params\n", + "from deepspeech.utils.layer_tools import summary" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "permanent-summary", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + " and should_run_async(code)\n", + "[INFO 2021/05/31 03:23:22 u2.py:839] U2 Encoder type: transformer\n", + "[INFO 2021/05/31 03:23:22 u2.py:840] attention_dropout_rate: 0.0\n", + "attention_heads: 4\n", + "dropout_rate: 0.1\n", + "input_layer: conv2d\n", + "linear_units: 2048\n", + "normalize_before: True\n", + "num_blocks: 12\n", + "output_size: 256\n", + "positional_dropout_rate: 0.1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "encoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", + "encoder.embed.conv.0.bias | [256] | 256 | True\n", + "encoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", + "encoder.embed.conv.2.bias | [256] | 256 | True\n", + "encoder.embed.out.0.weight | [5120, 256] | 1310720 | True\n", + "encoder.embed.out.0.bias | [256] | 256 | True\n", + "encoder.after_norm.weight | [256] | 256 | True\n", + "encoder.after_norm.bias | [256] | 256 | True\n", + "encoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.0.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.0.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.0.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.0.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.1.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.1.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.1.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.1.norm2.bias | [256] | 256 | True\n", + 
"encoder.encoders.1.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.1.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.2.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.2.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.2.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.2.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.2.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.2.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.3.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.3.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.3.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.3.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.3.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.3.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + 
"encoder.encoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.4.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.4.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.4.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.4.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.4.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.4.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.5.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.5.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.5.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.5.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.5.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.5.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.6.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.6.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.6.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.6.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.6.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.6.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.6.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.6.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.6.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.6.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.6.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.6.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.6.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.6.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.7.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.7.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.7.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.7.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.7.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.7.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.7.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + 
"encoder.encoders.7.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.7.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.7.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.7.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.7.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.7.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.7.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.7.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.7.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.7.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.7.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.8.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.8.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.8.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.8.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.8.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.8.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.8.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.8.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.8.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.8.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.8.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.8.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.8.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.8.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.9.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.9.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.9.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.9.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.9.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.9.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.9.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.9.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.9.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.9.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.9.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.9.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.9.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.9.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.10.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.10.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.10.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", 
+ "encoder.encoders.10.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.10.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.10.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.10.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.10.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.10.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.10.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.10.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.10.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.10.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.10.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.10.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.10.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.10.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.10.concat_linear.bias | [256] | 256 | True\n", + "encoder.encoders.11.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_q.bias | [256] | 256 | True\n", + "encoder.encoders.11.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_k.bias | [256] | 256 | True\n", + "encoder.encoders.11.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_v.bias | [256] | 256 | True\n", + "encoder.encoders.11.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_out.bias | [256] | 256 | True\n", + "encoder.encoders.11.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "encoder.encoders.11.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "encoder.encoders.11.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "encoder.encoders.11.feed_forward.w_2.bias | [256] | 256 | True\n", + "encoder.encoders.11.norm1.weight | [256] | 256 | True\n", + "encoder.encoders.11.norm1.bias | [256] | 256 | True\n", + "encoder.encoders.11.norm2.weight | [256] | 256 | True\n", + "encoder.encoders.11.norm2.bias | [256] | 256 | True\n", + "encoder.encoders.11.concat_linear.weight | [512, 256] | 131072 | True\n", + "encoder.encoders.11.concat_linear.bias | [256] | 256 | True\n", + "decoder.embed.0.weight | [4233, 256] | 1083648 | True\n", + "decoder.after_norm.weight | [256] | 256 | True\n", + "decoder.after_norm.bias | [256] | 256 | True\n", + "decoder.output_layer.weight | [256, 4233] | 1083648 | True\n", + "decoder.output_layer.bias | [4233] | 4233 | True\n", + "decoder.decoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.0.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.0.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_k.bias | [256] | 
256 | True\n", + "decoder.decoders.0.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.0.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "decoder.decoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "decoder.decoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "decoder.decoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", + "decoder.decoders.0.norm1.weight | [256] | 256 | True\n", + "decoder.decoders.0.norm1.bias | [256] | 256 | True\n", + "decoder.decoders.0.norm2.weight | [256] | 256 | True\n", + "decoder.decoders.0.norm2.bias | [256] | 256 | True\n", + "decoder.decoders.0.norm3.weight | [256] | 256 | True\n", + "decoder.decoders.0.norm3.bias | [256] | 256 | True\n", + "decoder.decoders.0.concat_linear1.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.0.concat_linear1.bias | [256] | 256 | True\n", + "decoder.decoders.0.concat_linear2.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.0.concat_linear2.bias | [256] | 256 | True\n", + "decoder.decoders.1.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.1.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.1.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.1.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.1.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.1.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.1.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.1.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.1.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "decoder.decoders.1.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "decoder.decoders.1.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "decoder.decoders.1.feed_forward.w_2.bias | [256] | 256 | True\n", + "decoder.decoders.1.norm1.weight | [256] | 256 | True\n", + "decoder.decoders.1.norm1.bias | [256] | 256 | True\n", + "decoder.decoders.1.norm2.weight | [256] | 256 | True\n", + "decoder.decoders.1.norm2.bias | [256] | 256 | True\n", + "decoder.decoders.1.norm3.weight | [256] | 256 | True\n", + "decoder.decoders.1.norm3.bias | [256] | 256 | True\n", + "decoder.decoders.1.concat_linear1.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.1.concat_linear1.bias | [256] | 256 | True\n", + "decoder.decoders.1.concat_linear2.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.1.concat_linear2.bias | [256] | 256 | True\n", + "decoder.decoders.2.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_q.bias | [256] | 256 | 
True\n", + "decoder.decoders.2.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.2.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.2.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.2.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.2.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.2.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.2.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.2.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "decoder.decoders.2.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "decoder.decoders.2.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "decoder.decoders.2.feed_forward.w_2.bias | [256] | 256 | True\n", + "decoder.decoders.2.norm1.weight | [256] | 256 | True\n", + "decoder.decoders.2.norm1.bias | [256] | 256 | True\n", + "decoder.decoders.2.norm2.weight | [256] | 256 | True\n", + "decoder.decoders.2.norm2.bias | [256] | 256 | True\n", + "decoder.decoders.2.norm3.weight | [256] | 256 | True\n", + "decoder.decoders.2.norm3.bias | [256] | 256 | True\n", + "decoder.decoders.2.concat_linear1.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.2.concat_linear1.bias | [256] | 256 | True\n", + "decoder.decoders.2.concat_linear2.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.2.concat_linear2.bias | [256] | 256 | True\n", + "decoder.decoders.3.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.3.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.3.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.3.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.3.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.3.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.3.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.3.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.3.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "decoder.decoders.3.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "decoder.decoders.3.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "decoder.decoders.3.feed_forward.w_2.bias | [256] | 256 | True\n", + "decoder.decoders.3.norm1.weight | 
[256] | 256 | True\n", + "decoder.decoders.3.norm1.bias | [256] | 256 | True\n", + "decoder.decoders.3.norm2.weight | [256] | 256 | True\n", + "decoder.decoders.3.norm2.bias | [256] | 256 | True\n", + "decoder.decoders.3.norm3.weight | [256] | 256 | True\n", + "decoder.decoders.3.norm3.bias | [256] | 256 | True\n", + "decoder.decoders.3.concat_linear1.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.3.concat_linear1.bias | [256] | 256 | True\n", + "decoder.decoders.3.concat_linear2.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.3.concat_linear2.bias | [256] | 256 | True\n", + "decoder.decoders.4.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.4.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.4.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.4.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.4.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.4.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.4.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.4.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.4.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "decoder.decoders.4.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "decoder.decoders.4.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "decoder.decoders.4.feed_forward.w_2.bias | [256] | 256 | True\n", + "decoder.decoders.4.norm1.weight | [256] | 256 | True\n", + "decoder.decoders.4.norm1.bias | [256] | 256 | True\n", + "decoder.decoders.4.norm2.weight | [256] | 256 | True\n", + "decoder.decoders.4.norm2.bias | [256] | 256 | True\n", + "decoder.decoders.4.norm3.weight | [256] | 256 | True\n", + "decoder.decoders.4.norm3.bias | [256] | 256 | True\n", + "decoder.decoders.4.concat_linear1.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.4.concat_linear1.bias | [256] | 256 | True\n", + "decoder.decoders.4.concat_linear2.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.4.concat_linear2.bias | [256] | 256 | True\n", + "decoder.decoders.5.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_q.bias | [256] | 256 | True\n", + "decoder.decoders.5.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.5.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.5.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.5.src_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_q.bias | [256] | 256 | True\n", + 
"decoder.decoders.5.src_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_k.bias | [256] | 256 | True\n", + "decoder.decoders.5.src_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_v.bias | [256] | 256 | True\n", + "decoder.decoders.5.src_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_out.bias | [256] | 256 | True\n", + "decoder.decoders.5.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "decoder.decoders.5.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "decoder.decoders.5.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "decoder.decoders.5.feed_forward.w_2.bias | [256] | 256 | True\n", + "decoder.decoders.5.norm1.weight | [256] | 256 | True\n", + "decoder.decoders.5.norm1.bias | [256] | 256 | True\n", + "decoder.decoders.5.norm2.weight | [256] | 256 | True\n", + "decoder.decoders.5.norm2.bias | [256] | 256 | True\n", + "decoder.decoders.5.norm3.weight | [256] | 256 | True\n", + "decoder.decoders.5.norm3.bias | [256] | 256 | True\n", + "decoder.decoders.5.concat_linear1.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.5.concat_linear1.bias | [256] | 256 | True\n", + "decoder.decoders.5.concat_linear2.weight | [512, 256] | 131072 | True\n", + "decoder.decoders.5.concat_linear2.bias | [256] | 256 | True\n", + "ctc.ctc_lo.weight | [256, 4233] | 1083648 | True\n", + "ctc.ctc_lo.bias | [4233] | 4233 | True\n", + "Total parameters: 411.0, 32.01M elements.\n" + ] + } + ], + "source": [ + "conf_str='examples/tiny/s1/conf/transformer.yaml'\n", + "cfg = CN().load_cfg(open(conf_str))\n", + "cfg.model.input_dim = 83\n", + "cfg.model.output_dim = 4233\n", + "cfg.model.cmvn_file = None\n", + "cfg.model.cmvn_file_type = 'json'\n", + "#cfg.model.encoder_conf.concat_after=True\n", + "cfg.freeze()\n", + "model = U2Model(cfg.model)\n", + "\n", + "print_params(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "sapphire-agent", + "metadata": {}, + "outputs": [], + "source": [ + "#summary(model)\n", + "#print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ruled-invitation", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fossil-means", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", + "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" + ] + } + ], + "source": [ + "%ls .notebook/espnet" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "45c2b75f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state\n", + "odict_keys(['mask_feature', 'encoder.embed.conv.0.weight', 'encoder.embed.conv.0.bias', 'encoder.embed.conv.2.weight', 'encoder.embed.conv.2.bias', 'encoder.embed.out.0.weight', 'encoder.embed.out.0.bias', 'encoder.encoders.0.self_attn.linear_q.weight', 'encoder.encoders.0.self_attn.linear_q.bias', 'encoder.encoders.0.self_attn.linear_k.weight', 'encoder.encoders.0.self_attn.linear_k.bias', 'encoder.encoders.0.self_attn.linear_v.weight', 'encoder.encoders.0.self_attn.linear_v.bias', 'encoder.encoders.0.self_attn.linear_out.weight', 'encoder.encoders.0.self_attn.linear_out.bias', 'encoder.encoders.0.feed_forward.w_1.weight', 'encoder.encoders.0.feed_forward.w_1.bias', 'encoder.encoders.0.feed_forward.w_2.weight', 
'encoder.encoders.0.feed_forward.w_2.bias', 'encoder.encoders.0.norm1.weight', 'encoder.encoders.0.norm1.bias', 'encoder.encoders.0.norm2.weight', 'encoder.encoders.0.norm2.bias', 'encoder.encoders.1.self_attn.linear_q.weight', 'encoder.encoders.1.self_attn.linear_q.bias', 'encoder.encoders.1.self_attn.linear_k.weight', 'encoder.encoders.1.self_attn.linear_k.bias', 'encoder.encoders.1.self_attn.linear_v.weight', 'encoder.encoders.1.self_attn.linear_v.bias', 'encoder.encoders.1.self_attn.linear_out.weight', 'encoder.encoders.1.self_attn.linear_out.bias', 'encoder.encoders.1.feed_forward.w_1.weight', 'encoder.encoders.1.feed_forward.w_1.bias', 'encoder.encoders.1.feed_forward.w_2.weight', 'encoder.encoders.1.feed_forward.w_2.bias', 'encoder.encoders.1.norm1.weight', 'encoder.encoders.1.norm1.bias', 'encoder.encoders.1.norm2.weight', 'encoder.encoders.1.norm2.bias', 'encoder.encoders.2.self_attn.linear_q.weight', 'encoder.encoders.2.self_attn.linear_q.bias', 'encoder.encoders.2.self_attn.linear_k.weight', 'encoder.encoders.2.self_attn.linear_k.bias', 'encoder.encoders.2.self_attn.linear_v.weight', 'encoder.encoders.2.self_attn.linear_v.bias', 'encoder.encoders.2.self_attn.linear_out.weight', 'encoder.encoders.2.self_attn.linear_out.bias', 'encoder.encoders.2.feed_forward.w_1.weight', 'encoder.encoders.2.feed_forward.w_1.bias', 'encoder.encoders.2.feed_forward.w_2.weight', 'encoder.encoders.2.feed_forward.w_2.bias', 'encoder.encoders.2.norm1.weight', 'encoder.encoders.2.norm1.bias', 'encoder.encoders.2.norm2.weight', 'encoder.encoders.2.norm2.bias', 'encoder.encoders.3.self_attn.linear_q.weight', 'encoder.encoders.3.self_attn.linear_q.bias', 'encoder.encoders.3.self_attn.linear_k.weight', 'encoder.encoders.3.self_attn.linear_k.bias', 'encoder.encoders.3.self_attn.linear_v.weight', 'encoder.encoders.3.self_attn.linear_v.bias', 'encoder.encoders.3.self_attn.linear_out.weight', 'encoder.encoders.3.self_attn.linear_out.bias', 'encoder.encoders.3.feed_forward.w_1.weight', 'encoder.encoders.3.feed_forward.w_1.bias', 'encoder.encoders.3.feed_forward.w_2.weight', 'encoder.encoders.3.feed_forward.w_2.bias', 'encoder.encoders.3.norm1.weight', 'encoder.encoders.3.norm1.bias', 'encoder.encoders.3.norm2.weight', 'encoder.encoders.3.norm2.bias', 'encoder.encoders.4.self_attn.linear_q.weight', 'encoder.encoders.4.self_attn.linear_q.bias', 'encoder.encoders.4.self_attn.linear_k.weight', 'encoder.encoders.4.self_attn.linear_k.bias', 'encoder.encoders.4.self_attn.linear_v.weight', 'encoder.encoders.4.self_attn.linear_v.bias', 'encoder.encoders.4.self_attn.linear_out.weight', 'encoder.encoders.4.self_attn.linear_out.bias', 'encoder.encoders.4.feed_forward.w_1.weight', 'encoder.encoders.4.feed_forward.w_1.bias', 'encoder.encoders.4.feed_forward.w_2.weight', 'encoder.encoders.4.feed_forward.w_2.bias', 'encoder.encoders.4.norm1.weight', 'encoder.encoders.4.norm1.bias', 'encoder.encoders.4.norm2.weight', 'encoder.encoders.4.norm2.bias', 'encoder.encoders.5.self_attn.linear_q.weight', 'encoder.encoders.5.self_attn.linear_q.bias', 'encoder.encoders.5.self_attn.linear_k.weight', 'encoder.encoders.5.self_attn.linear_k.bias', 'encoder.encoders.5.self_attn.linear_v.weight', 'encoder.encoders.5.self_attn.linear_v.bias', 'encoder.encoders.5.self_attn.linear_out.weight', 'encoder.encoders.5.self_attn.linear_out.bias', 'encoder.encoders.5.feed_forward.w_1.weight', 'encoder.encoders.5.feed_forward.w_1.bias', 'encoder.encoders.5.feed_forward.w_2.weight', 'encoder.encoders.5.feed_forward.w_2.bias', 
'encoder.encoders.5.norm1.weight', 'encoder.encoders.5.norm1.bias', 'encoder.encoders.5.norm2.weight', 'encoder.encoders.5.norm2.bias', 'encoder.encoders.6.self_attn.linear_q.weight', 'encoder.encoders.6.self_attn.linear_q.bias', 'encoder.encoders.6.self_attn.linear_k.weight', 'encoder.encoders.6.self_attn.linear_k.bias', 'encoder.encoders.6.self_attn.linear_v.weight', 'encoder.encoders.6.self_attn.linear_v.bias', 'encoder.encoders.6.self_attn.linear_out.weight', 'encoder.encoders.6.self_attn.linear_out.bias', 'encoder.encoders.6.feed_forward.w_1.weight', 'encoder.encoders.6.feed_forward.w_1.bias', 'encoder.encoders.6.feed_forward.w_2.weight', 'encoder.encoders.6.feed_forward.w_2.bias', 'encoder.encoders.6.norm1.weight', 'encoder.encoders.6.norm1.bias', 'encoder.encoders.6.norm2.weight', 'encoder.encoders.6.norm2.bias', 'encoder.encoders.7.self_attn.linear_q.weight', 'encoder.encoders.7.self_attn.linear_q.bias', 'encoder.encoders.7.self_attn.linear_k.weight', 'encoder.encoders.7.self_attn.linear_k.bias', 'encoder.encoders.7.self_attn.linear_v.weight', 'encoder.encoders.7.self_attn.linear_v.bias', 'encoder.encoders.7.self_attn.linear_out.weight', 'encoder.encoders.7.self_attn.linear_out.bias', 'encoder.encoders.7.feed_forward.w_1.weight', 'encoder.encoders.7.feed_forward.w_1.bias', 'encoder.encoders.7.feed_forward.w_2.weight', 'encoder.encoders.7.feed_forward.w_2.bias', 'encoder.encoders.7.norm1.weight', 'encoder.encoders.7.norm1.bias', 'encoder.encoders.7.norm2.weight', 'encoder.encoders.7.norm2.bias', 'encoder.encoders.8.self_attn.linear_q.weight', 'encoder.encoders.8.self_attn.linear_q.bias', 'encoder.encoders.8.self_attn.linear_k.weight', 'encoder.encoders.8.self_attn.linear_k.bias', 'encoder.encoders.8.self_attn.linear_v.weight', 'encoder.encoders.8.self_attn.linear_v.bias', 'encoder.encoders.8.self_attn.linear_out.weight', 'encoder.encoders.8.self_attn.linear_out.bias', 'encoder.encoders.8.feed_forward.w_1.weight', 'encoder.encoders.8.feed_forward.w_1.bias', 'encoder.encoders.8.feed_forward.w_2.weight', 'encoder.encoders.8.feed_forward.w_2.bias', 'encoder.encoders.8.norm1.weight', 'encoder.encoders.8.norm1.bias', 'encoder.encoders.8.norm2.weight', 'encoder.encoders.8.norm2.bias', 'encoder.encoders.9.self_attn.linear_q.weight', 'encoder.encoders.9.self_attn.linear_q.bias', 'encoder.encoders.9.self_attn.linear_k.weight', 'encoder.encoders.9.self_attn.linear_k.bias', 'encoder.encoders.9.self_attn.linear_v.weight', 'encoder.encoders.9.self_attn.linear_v.bias', 'encoder.encoders.9.self_attn.linear_out.weight', 'encoder.encoders.9.self_attn.linear_out.bias', 'encoder.encoders.9.feed_forward.w_1.weight', 'encoder.encoders.9.feed_forward.w_1.bias', 'encoder.encoders.9.feed_forward.w_2.weight', 'encoder.encoders.9.feed_forward.w_2.bias', 'encoder.encoders.9.norm1.weight', 'encoder.encoders.9.norm1.bias', 'encoder.encoders.9.norm2.weight', 'encoder.encoders.9.norm2.bias', 'encoder.encoders.10.self_attn.linear_q.weight', 'encoder.encoders.10.self_attn.linear_q.bias', 'encoder.encoders.10.self_attn.linear_k.weight', 'encoder.encoders.10.self_attn.linear_k.bias', 'encoder.encoders.10.self_attn.linear_v.weight', 'encoder.encoders.10.self_attn.linear_v.bias', 'encoder.encoders.10.self_attn.linear_out.weight', 'encoder.encoders.10.self_attn.linear_out.bias', 'encoder.encoders.10.feed_forward.w_1.weight', 'encoder.encoders.10.feed_forward.w_1.bias', 'encoder.encoders.10.feed_forward.w_2.weight', 'encoder.encoders.10.feed_forward.w_2.bias', 'encoder.encoders.10.norm1.weight', 
'encoder.encoders.10.norm1.bias', 'encoder.encoders.10.norm2.weight', 'encoder.encoders.10.norm2.bias', 'encoder.encoders.11.self_attn.linear_q.weight', 'encoder.encoders.11.self_attn.linear_q.bias', 'encoder.encoders.11.self_attn.linear_k.weight', 'encoder.encoders.11.self_attn.linear_k.bias', 'encoder.encoders.11.self_attn.linear_v.weight', 'encoder.encoders.11.self_attn.linear_v.bias', 'encoder.encoders.11.self_attn.linear_out.weight', 'encoder.encoders.11.self_attn.linear_out.bias', 'encoder.encoders.11.feed_forward.w_1.weight', 'encoder.encoders.11.feed_forward.w_1.bias', 'encoder.encoders.11.feed_forward.w_2.weight', 'encoder.encoders.11.feed_forward.w_2.bias', 'encoder.encoders.11.norm1.weight', 'encoder.encoders.11.norm1.bias', 'encoder.encoders.11.norm2.weight', 'encoder.encoders.11.norm2.bias', 'encoder.after_norm.weight', 'encoder.after_norm.bias', 'decoder.embed.0.weight', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'sfc.weight', 'sfc.bias', 'deconv.0.weight', 'deconv.0.bias', 'deconv.1.weight', 'deconv.1.bias', 'xlm_embed.0.weight', 'xlm_pred.weight', 'xlm_pred.bias'])\n" + ] + } + ], + "source": [ + "#!pip install torch\n", + "import torch\n", + "\n", + "e_model = np.load('.notebook/espnet/model.npz',allow_pickle=True)\n", + "for k in e_model.files:\n", + " print(k)\n", + "state_dict = e_model['state']\n", + "state_dict = state_dict.tolist()\n", + "print(state_dict.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f187bb55", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", + "  and should_run_async(code)\n" + ] + } + ], + "source": [ + "# embed.conv.0.weight None torch.Size([256, 1, 3, 3]) \tencoder.embed.conv.0.weight | [256, 1, 3, 3] | 2304 | True\n", + "# embed.conv.0.bias None torch.Size([256]) \tencoder.embed.conv.0.bias | [256] | 256 | True\n", + "# embed.conv.2.weight None torch.Size([256, 256, 3, 3]) \tencoder.embed.conv.2.weight | [256, 256, 3, 3] | 589824 | True\n", + "# embed.conv.2.bias None torch.Size([256]) \tencoder.embed.conv.2.bias | [256] | 256 | True\n", + "# embed.out.0.weight None torch.Size([256, 5120]) 83 feature\tencoder.embed.out.0.weight | [4864, 256] | 1245184 | True 80 feature\n", + "# embed.out.0.bias None torch.Size([256]) \tencoder.embed.out.0.bias | [256] | 256 | True\n", + "# after_norm.weight None torch.Size([256]) \tencoder.after_norm.weight | [256] | 256 | True\n", + "# after_norm.bias None torch.Size([256]) \tencoder.after_norm.bias | [256] | 256 | True\n", + "# encoders.9.self_attn.linear_q.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_q.weight | [256, 256] | 65536 | True\n", + "# encoders.9.self_attn.linear_q.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_q.bias | [256] | 256 | True\n", + "# encoders.9.self_attn.linear_k.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_k.weight | [256, 256] | 65536 | True\n", + "# encoders.9.self_attn.linear_k.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_k.bias | [256] | 256 | True\n", + "# encoders.9.self_attn.linear_v.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_v.weight | [256, 256] | 65536 | True\n", + "# encoders.9.self_attn.linear_v.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_v.bias | [256] | 256 | True\n", + "# encoders.9.self_attn.linear_out.weight None torch.Size([256, 256]) \tencoder.encoders.0.self_attn.linear_out.weight | [256, 256] | 65536 | True\n", + "# encoders.9.self_attn.linear_out.bias None torch.Size([256]) \tencoder.encoders.0.self_attn.linear_out.bias | [256] | 256 | True\n", + "# encoders.9.feed_forward.w_1.weight None torch.Size([2048, 256]) \tencoder.encoders.0.feed_forward.w_1.weight | [256, 2048] | 524288 | True\n", + "# encoders.9.feed_forward.w_1.bias None torch.Size([2048]) \tencoder.encoders.0.feed_forward.w_1.bias | [2048] | 2048 | True\n", + "# encoders.9.feed_forward.w_2.weight None torch.Size([256, 2048]) \tencoder.encoders.0.feed_forward.w_2.weight | [2048, 256] | 524288 | True\n", + "# encoders.9.feed_forward.w_2.bias None torch.Size([256]) \tencoder.encoders.0.feed_forward.w_2.bias | [256] | 256 | True\n", + "# encoders.9.norm1.weight None torch.Size([256]) \tencoder.encoders.0.norm1.weight | [256] | 256 | True\n", + "# encoders.9.norm1.bias None torch.Size([256]) \tencoder.encoders.0.norm1.bias | [256] | 256 | True\n", + "# encoders.9.norm2.weight None torch.Size([256]) \tencoder.encoders.0.norm2.weight | [256] | 256 | True\n", + "# encoders.9.norm2.bias None torch.Size([256]) \tencoder.encoders.0.norm2.bias | [256] | 256 | True\n", + "# \tencoder.encoders.0.concat_linear.weight | [512, 256] | 131072 | True\n", + "# \tencoder.encoders.0.concat_linear.bias | [256] | 256 | True\n", + "# espnet transformer\tconcat_linear is only saved in the checkpoint, but never used\n", + "\t\n", + "# \tpaddle transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2a0428ae", + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-> encoder.embed.conv.0.weight\n", + "-> encoder.embed.conv.0.bias\n", + "-> encoder.embed.conv.2.weight\n", + "-> encoder.embed.conv.2.bias\n", + "-> encoder.embed.out.0.weight\n", + "encoder.embed.out.0.weight: (256, 5120) -> (5120, 256)\n", + "-> encoder.embed.out.0.bias\n", + "-> encoder.encoders.0.self_attn.linear_q.weight\n", + "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.0.self_attn.linear_q.bias\n", + "-> encoder.encoders.0.self_attn.linear_k.weight\n", + "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.0.self_attn.linear_k.bias\n", + "-> encoder.encoders.0.self_attn.linear_v.weight\n", + "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.0.self_attn.linear_v.bias\n", + "-> encoder.encoders.0.self_attn.linear_out.weight\n", + "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.0.self_attn.linear_out.bias\n", + "-> encoder.encoders.0.feed_forward.w_1.weight\n", + "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.0.feed_forward.w_1.bias\n", + "-> encoder.encoders.0.feed_forward.w_2.weight\n", + "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.0.feed_forward.w_2.bias\n", + "-> encoder.encoders.0.norm1.weight\n", + "-> encoder.encoders.0.norm1.bias\n", + "-> encoder.encoders.0.norm2.weight\n", + "-> encoder.encoders.0.norm2.bias\n", + "-> encoder.encoders.1.self_attn.linear_q.weight\n", + "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.1.self_attn.linear_q.bias\n", + "-> encoder.encoders.1.self_attn.linear_k.weight\n", + "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.1.self_attn.linear_k.bias\n", + "-> encoder.encoders.1.self_attn.linear_v.weight\n", + "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.1.self_attn.linear_v.bias\n", + "-> encoder.encoders.1.self_attn.linear_out.weight\n", + "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.1.self_attn.linear_out.bias\n", + "-> encoder.encoders.1.feed_forward.w_1.weight\n", + "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.1.feed_forward.w_1.bias\n", + "-> encoder.encoders.1.feed_forward.w_2.weight\n", + "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.1.feed_forward.w_2.bias\n", + "-> encoder.encoders.1.norm1.weight\n", + "-> encoder.encoders.1.norm1.bias\n", + "-> encoder.encoders.1.norm2.weight\n", + "-> encoder.encoders.1.norm2.bias\n", + "-> encoder.encoders.2.self_attn.linear_q.weight\n", + "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.2.self_attn.linear_q.bias\n", + "-> encoder.encoders.2.self_attn.linear_k.weight\n", + "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.2.self_attn.linear_k.bias\n", + "-> encoder.encoders.2.self_attn.linear_v.weight\n", + "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.2.self_attn.linear_v.bias\n", + "-> encoder.encoders.2.self_attn.linear_out.weight\n", + 
"encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.2.self_attn.linear_out.bias\n", + "-> encoder.encoders.2.feed_forward.w_1.weight\n", + "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.2.feed_forward.w_1.bias\n", + "-> encoder.encoders.2.feed_forward.w_2.weight\n", + "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.2.feed_forward.w_2.bias\n", + "-> encoder.encoders.2.norm1.weight\n", + "-> encoder.encoders.2.norm1.bias\n", + "-> encoder.encoders.2.norm2.weight\n", + "-> encoder.encoders.2.norm2.bias\n", + "-> encoder.encoders.3.self_attn.linear_q.weight\n", + "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.3.self_attn.linear_q.bias\n", + "-> encoder.encoders.3.self_attn.linear_k.weight\n", + "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.3.self_attn.linear_k.bias\n", + "-> encoder.encoders.3.self_attn.linear_v.weight\n", + "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.3.self_attn.linear_v.bias\n", + "-> encoder.encoders.3.self_attn.linear_out.weight\n", + "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.3.self_attn.linear_out.bias\n", + "-> encoder.encoders.3.feed_forward.w_1.weight\n", + "encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.3.feed_forward.w_1.bias\n", + "-> encoder.encoders.3.feed_forward.w_2.weight\n", + "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.3.feed_forward.w_2.bias\n", + "-> encoder.encoders.3.norm1.weight\n", + "-> encoder.encoders.3.norm1.bias\n", + "-> encoder.encoders.3.norm2.weight\n", + "-> encoder.encoders.3.norm2.bias\n", + "-> encoder.encoders.4.self_attn.linear_q.weight\n", + "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.4.self_attn.linear_q.bias\n", + "-> encoder.encoders.4.self_attn.linear_k.weight\n", + "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.4.self_attn.linear_k.bias\n", + "-> encoder.encoders.4.self_attn.linear_v.weight\n", + "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.4.self_attn.linear_v.bias\n", + "-> encoder.encoders.4.self_attn.linear_out.weight\n", + "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.4.self_attn.linear_out.bias\n", + "-> encoder.encoders.4.feed_forward.w_1.weight\n", + "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.4.feed_forward.w_1.bias\n", + "-> encoder.encoders.4.feed_forward.w_2.weight\n", + "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.4.feed_forward.w_2.bias\n", + "-> encoder.encoders.4.norm1.weight\n", + "-> encoder.encoders.4.norm1.bias\n", + "-> encoder.encoders.4.norm2.weight\n", + "-> encoder.encoders.4.norm2.bias\n", + "-> encoder.encoders.5.self_attn.linear_q.weight\n", + "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.5.self_attn.linear_q.bias\n", + "-> encoder.encoders.5.self_attn.linear_k.weight\n", + "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> 
encoder.encoders.5.self_attn.linear_k.bias\n", + "-> encoder.encoders.5.self_attn.linear_v.weight\n", + "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.5.self_attn.linear_v.bias\n", + "-> encoder.encoders.5.self_attn.linear_out.weight\n", + "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.5.self_attn.linear_out.bias\n", + "-> encoder.encoders.5.feed_forward.w_1.weight\n", + "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.5.feed_forward.w_1.bias\n", + "-> encoder.encoders.5.feed_forward.w_2.weight\n", + "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.5.feed_forward.w_2.bias\n", + "-> encoder.encoders.5.norm1.weight\n", + "-> encoder.encoders.5.norm1.bias\n", + "-> encoder.encoders.5.norm2.weight\n", + "-> encoder.encoders.5.norm2.bias\n", + "-> encoder.encoders.6.self_attn.linear_q.weight\n", + "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.6.self_attn.linear_q.bias\n", + "-> encoder.encoders.6.self_attn.linear_k.weight\n", + "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.6.self_attn.linear_k.bias\n", + "-> encoder.encoders.6.self_attn.linear_v.weight\n", + "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.6.self_attn.linear_v.bias\n", + "-> encoder.encoders.6.self_attn.linear_out.weight\n", + "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.6.self_attn.linear_out.bias\n", + "-> encoder.encoders.6.feed_forward.w_1.weight\n", + "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.6.feed_forward.w_1.bias\n", + "-> encoder.encoders.6.feed_forward.w_2.weight\n", + "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.6.feed_forward.w_2.bias\n", + "-> encoder.encoders.6.norm1.weight\n", + "-> encoder.encoders.6.norm1.bias\n", + "-> encoder.encoders.6.norm2.weight\n", + "-> encoder.encoders.6.norm2.bias\n", + "-> encoder.encoders.7.self_attn.linear_q.weight\n", + "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.7.self_attn.linear_q.bias\n", + "-> encoder.encoders.7.self_attn.linear_k.weight\n", + "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.7.self_attn.linear_k.bias\n", + "-> encoder.encoders.7.self_attn.linear_v.weight\n", + "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.7.self_attn.linear_v.bias\n", + "-> encoder.encoders.7.self_attn.linear_out.weight\n", + "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.7.self_attn.linear_out.bias\n", + "-> encoder.encoders.7.feed_forward.w_1.weight\n", + "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.7.feed_forward.w_1.bias\n", + "-> encoder.encoders.7.feed_forward.w_2.weight\n", + "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.7.feed_forward.w_2.bias\n", + "-> encoder.encoders.7.norm1.weight\n", + "-> encoder.encoders.7.norm1.bias\n", + "-> encoder.encoders.7.norm2.weight\n", + "-> encoder.encoders.7.norm2.bias\n", + "-> 
encoder.encoders.8.self_attn.linear_q.weight\n", + "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.8.self_attn.linear_q.bias\n", + "-> encoder.encoders.8.self_attn.linear_k.weight\n", + "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.8.self_attn.linear_k.bias\n", + "-> encoder.encoders.8.self_attn.linear_v.weight\n", + "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.8.self_attn.linear_v.bias\n", + "-> encoder.encoders.8.self_attn.linear_out.weight\n", + "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.8.self_attn.linear_out.bias\n", + "-> encoder.encoders.8.feed_forward.w_1.weight\n", + "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.8.feed_forward.w_1.bias\n", + "-> encoder.encoders.8.feed_forward.w_2.weight\n", + "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.8.feed_forward.w_2.bias\n", + "-> encoder.encoders.8.norm1.weight\n", + "-> encoder.encoders.8.norm1.bias\n", + "-> encoder.encoders.8.norm2.weight\n", + "-> encoder.encoders.8.norm2.bias\n", + "-> encoder.encoders.9.self_attn.linear_q.weight\n", + "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.9.self_attn.linear_q.bias\n", + "-> encoder.encoders.9.self_attn.linear_k.weight\n", + "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.9.self_attn.linear_k.bias\n", + "-> encoder.encoders.9.self_attn.linear_v.weight\n", + "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.9.self_attn.linear_v.bias\n", + "-> encoder.encoders.9.self_attn.linear_out.weight\n", + "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.9.self_attn.linear_out.bias\n", + "-> encoder.encoders.9.feed_forward.w_1.weight\n", + "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.9.feed_forward.w_1.bias\n", + "-> encoder.encoders.9.feed_forward.w_2.weight\n", + "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.9.feed_forward.w_2.bias\n", + "-> encoder.encoders.9.norm1.weight\n", + "-> encoder.encoders.9.norm1.bias\n", + "-> encoder.encoders.9.norm2.weight\n", + "-> encoder.encoders.9.norm2.bias\n", + "-> encoder.encoders.10.self_attn.linear_q.weight\n", + "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.10.self_attn.linear_q.bias\n", + "-> encoder.encoders.10.self_attn.linear_k.weight\n", + "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.10.self_attn.linear_k.bias\n", + "-> encoder.encoders.10.self_attn.linear_v.weight\n", + "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.10.self_attn.linear_v.bias\n", + "-> encoder.encoders.10.self_attn.linear_out.weight\n", + "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.10.self_attn.linear_out.bias\n", + "-> encoder.encoders.10.feed_forward.w_1.weight\n", + "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.10.feed_forward.w_1.bias\n", + "-> encoder.encoders.10.feed_forward.w_2.weight\n", + 
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.10.feed_forward.w_2.bias\n", + "-> encoder.encoders.10.norm1.weight\n", + "-> encoder.encoders.10.norm1.bias\n", + "-> encoder.encoders.10.norm2.weight\n", + "-> encoder.encoders.10.norm2.bias\n", + "-> encoder.encoders.11.self_attn.linear_q.weight\n", + "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.11.self_attn.linear_q.bias\n", + "-> encoder.encoders.11.self_attn.linear_k.weight\n", + "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.11.self_attn.linear_k.bias\n", + "-> encoder.encoders.11.self_attn.linear_v.weight\n", + "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.11.self_attn.linear_v.bias\n", + "-> encoder.encoders.11.self_attn.linear_out.weight\n", + "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "-> encoder.encoders.11.self_attn.linear_out.bias\n", + "-> encoder.encoders.11.feed_forward.w_1.weight\n", + "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "-> encoder.encoders.11.feed_forward.w_1.bias\n", + "-> encoder.encoders.11.feed_forward.w_2.weight\n", + "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "-> encoder.encoders.11.feed_forward.w_2.bias\n", + "-> encoder.encoders.11.norm1.weight\n", + "-> encoder.encoders.11.norm1.bias\n", + "-> encoder.encoders.11.norm2.weight\n", + "-> encoder.encoders.11.norm2.bias\n", + "-> encoder.after_norm.weight\n", + "-> encoder.after_norm.bias\n" + ] + } + ], + "source": [ + "# dump torch model to paddle\n", + "#state_dict = model.state_dict()\n", + "paddle_state_dict = {}\n", + "\n", + "for n, p in state_dict.items():\n", + " if 'encoder' not in n:\n", + " continue \n", + " print(f'-> {n}')\n", + " \n", + " \n", + " name_change=True\n", + " if 'norm.running_mean' in n:\n", + " new_n = n.replace('norm.running_', 'norm._')\n", + " elif 'norm.running_var' in n:\n", + " new_n = n.replace('norm.running_var', 'norm._variance')\n", + " else:\n", + " name_change=False\n", + " new_n = n\n", + " if name_change:\n", + " print(f\"{n} -> {new_n}\")\n", + " \n", + " \n", + " p = p.cpu().detach().numpy()\n", + " if n.endswith('weight') and p.ndim == 2:\n", + " new_p = p.T\n", + " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", + " else:\n", + " new_p = p\n", + " \n", + " if 'global_cmvn.mean' in n:\n", + " print(p, p.dtype)\n", + " \n", + " paddle_state_dict[new_n] = new_p\n", + " \n", + "# np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", + "# state=paddle_state_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a1d97e9f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.weight. encoder.encoders.0.concat_linear.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.0.concat_linear.bias. encoder.encoders.0.concat_linear.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.weight. encoder.encoders.1.concat_linear.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.1.concat_linear.bias. encoder.encoders.1.concat_linear.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.weight. encoder.encoders.2.concat_linear.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.2.concat_linear.bias. encoder.encoders.2.concat_linear.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.weight. encoder.encoders.3.concat_linear.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.3.concat_linear.bias. encoder.encoders.3.concat_linear.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.weight. encoder.encoders.4.concat_linear.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.4.concat_linear.bias. encoder.encoders.4.concat_linear.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.weight. encoder.encoders.5.concat_linear.weight is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.5.concat_linear.bias. encoder.encoders.5.concat_linear.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for encoder.encoders.6.concat_linear.weight. 
encoder.encoders.6.concat_linear.weight is not found in the provided dict.\n",
+    "  warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n",
+    "[... the same two-line UserWarning from paddle/fluid/dygraph/layers.py:1303 repeats for every remaining key missing from the provided dict: encoder.encoders.6-11.concat_linear.{weight,bias}; decoder.embed.0.weight; decoder.after_norm.{weight,bias}; decoder.output_layer.{weight,bias}; for decoder.decoders.0-5 the self_attn and src_attn linear_{q,k,v,out}.{weight,bias}, feed_forward.w_{1,2}.{weight,bias}, norm{1,2,3}.{weight,bias} and concat_linear{1,2}.{weight,bias}; and ctc.ctc_lo.{weight,bias} ...]\n",
+    "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.weight. ctc.ctc_lo.weight is not found in the provided dict.\n",
+    "  warnings.warn((\"Skip loading for {}. 
\".format(key) + str(err)))\n", + "/workspace/DeepSpeech-2.x/tools/venv-2p1/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for ctc.ctc_lo.bias. ctc.ctc_lo.bias is not found in the provided dict.\n", + " warnings.warn((\"Skip loading for {}. \".format(key) + str(err)))\n" + ] + } + ], + "source": [ + "model.set_state_dict(paddle_state_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "fc7edf1e", + "metadata": {}, + "outputs": [], + "source": [ + "e_state = model.encoder.state_dict()\n", + "for key, value in e_state.items():\n", + " if 'concat_linear' in key:\n", + " continue\n", + " if not np.allclose(value.numpy(), paddle_state_dict['encoder.' + key]):\n", + " print(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "572097d0", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "748250b7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91e5deee", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "fleet-despite", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "embed.npz feat.npz l1.npz l11.npz l3.npz l5.npz l7.npz l9.npz\r\n", + "encoder.npz l0.npz l10.npz l2.npz l4.npz l6.npz l8.npz model.npz\r\n" + ] + } + ], + "source": [ + "%ls .notebook/espnet" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "abroad-oracle", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(8, 57, 83)\n", + "(8, 1, 57)\n", + "[57 50 48 38 32 31 28 25]\n" + ] + } + ], + "source": [ + "data = np.load('.notebook/espnet/feat.npz', allow_pickle=True)\n", + "xs=data['xs']\n", + "masks=data['masks']\n", + "print(xs.shape)\n", + "print(masks.shape)\n", + "xs_lens = masks.sum(axis=-1).squeeze()\n", + "print(xs_lens)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "false-instrument", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[8, 13, 256]\n", + "[8, 1, 13]\n" + ] + } + ], + "source": [ + "# ecnoder\n", + "xs = paddle.to_tensor(xs, dtype='float32')\n", + "x_lens = paddle.to_tensor(xs_lens, dtype='int32')\n", + "model.eval()\n", + "encoder_out, encoder_mask = model.encoder(xs, x_lens)\n", + "print(encoder_out.shape)\n", + "print(encoder_mask.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "arctic-proxy", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(8, 13, 256)\n", + "(8, 1, 13)\n", + "False\n", + "False\n", + "True\n", + "True\n" + ] + } + ], + "source": [ + "data = np.load('.notebook/espnet/encoder.npz', allow_pickle=True)\n", + "xs = data['xs']\n", + "masks = data['masks']\n", + "print(xs.shape)\n", + "print(masks.shape)\n", + "print(np.allclose(xs, encoder_out.numpy()))\n", + "print(np.allclose(xs, encoder_out.numpy(), atol=1e-6))\n", + "print(np.allclose(xs, encoder_out.numpy(), atol=1e-5))\n", + "print(np.allclose(masks, encoder_mask.numpy()))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "seasonal-switch", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 2.1380312 1.8675405 -1.1873871 ... -0.30456656 0.56382173\n", + " -0.6526459 ]\n", + " [ 2.1926146 2.1373641 -0.6548196 ... 
-0.897318 0.6044322\n", + " -0.63332295]\n", + " [ 1.6367635 2.3320658 -0.8848577 ... -0.9640939 1.2420733\n", + " -0.05243584]\n", + " ...\n", + " [ 1.8533031 1.8421621 -0.6728406 ... 0.04810616 0.6459763\n", + " -0.18188554]\n", + " [ 2.0894065 1.7813934 -1.1591585 ... -0.09513803 0.8321831\n", + " -0.72916794]\n", + " [ 1.6488649 2.0984242 -1.3490562 ... 0.42678255 0.5903866\n", + " -0.32597935]]\n", + "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[ 2.13803196, 1.86753929, -1.18738675, ..., -0.30456796, 0.56382364, -0.65264463],\n", + " [ 2.19261336, 2.13736486, -0.65482187, ..., -0.89731705, 0.60443199, -0.63332343],\n", + " [ 1.63676369, 2.33206534, -0.88485885, ..., -0.96409231, 1.24207270, -0.05243752],\n", + " ...,\n", + " [ 1.85330284, 1.84216177, -0.67284071, ..., 0.04810715, 0.64597648, -0.18188696],\n", + " [ 2.08940673, 1.78139246, -1.15916038, ..., -0.09513779, 0.83218288, -0.72916913],\n", + " [ 1.64886570, 2.09842515, -1.34905660, ..., 0.42678308, 0.59038705, -0.32598034]])\n" + ] + } + ], + "source": [ + "print(xs[0])\n", + "print(encoder_out[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "defined-brooks", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 2.209824 1.5208759 0.1417884 ... -0.73617566 1.6538682\n", + " -0.16355833]\n", + " [ 2.1441019 1.4377339 0.3629197 ... -0.91226125 1.3739952\n", + " 0.11874156]\n", + " [ 1.8725398 1.5417286 0.38919652 ... -0.89621615 1.1841662\n", + " 0.27621832]\n", + " ...\n", + " [ 2.4591084 0.7238764 -1.1456345 ... -0.24188249 0.8232168\n", + " -0.9794884 ]\n", + " [ 2.5156236 1.1919155 -0.97032744 ... -0.7360675 1.0647209\n", + " -1.3076135 ]\n", + " [ 2.160009 0.98425585 -1.2231126 ... 
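The allclose checks above pass only once atol is relaxed to 1e-5, which is typical float32 drift between frameworks. A small helper along these lines can make such a tolerance sweep more informative by also reporting the largest absolute difference; report_diff is an illustrative name, not part of the notebook or either framework:

import numpy as np

def report_diff(ref, out, atols=(1e-8, 1e-6, 1e-5, 1e-4)):
    """Print the max absolute difference and allclose results at several tolerances."""
    ref = np.asarray(ref, dtype=np.float64)
    out = np.asarray(out, dtype=np.float64)
    print("max |ref - out| =", np.abs(ref - out).max())
    for atol in atols:
        print(f"allclose(atol={atol}):", np.allclose(ref, out, atol=atol))

# e.g. report_diff(xs, encoder_out.numpy()) for the espnet vs. paddle encoder outputs above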
-0.03393313 1.9141548\n", + " -1.0099151 ]]\n", + "Tensor(shape=[13, 256], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", + " [[ 2.20982409, 1.52087593, 0.14178854, ..., -0.73617446, 1.65386844, -0.16355731],\n", + " [ 2.14410043, 1.43773460, 0.36291891, ..., -0.91226172, 1.37399518, 0.11874183],\n", + " [ 1.87254059, 1.54172909, 0.38919681, ..., -0.89621687, 1.18416822, 0.27621880],\n", + " ...,\n", + " [ 2.45910931, 0.72387671, -1.14563596, ..., -0.24188218, 0.82321703, -0.97948682],\n", + " [ 2.51562238, 1.19191694, -0.97032893, ..., -0.73606837, 1.06472087, -1.30761123],\n", + " [ 2.16000915, 0.98425680, -1.22311163, ..., -0.03393326, 1.91415381, -1.00991392]])\n" + ] + } + ], + "source": [ + "print(xs[1])\n", + "print(encoder_out[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0504e3f8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/.notebook/wenet_model.ipynb b/.notebook/wenet_model.ipynb new file mode 100644 index 00000000..8e10b6c4 --- /dev/null +++ b/.notebook/wenet_model.ipynb @@ -0,0 +1,5015 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "cfb832c0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/wenet\n" + ] + }, + { + "data": { + "text/plain": [ + "'/workspace/wenet'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%cd /workspace/wenet/\n", + "%pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "62277538", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import argparse\n", + "import copy\n", + "import logging\n", + "import os\n", + "\n", + "import torch\n", + "import torch.distributed as dist\n", + "import torch.optim as optim\n", + "import yaml\n", + "from tensorboardX import SummaryWriter\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from wenet.dataset.dataset import AudioDataset, CollateFunc\n", + "from wenet.transformer.asr_model import init_asr_model\n", + "from wenet.utils.checkpoint import load_checkpoint, save_checkpoint\n", + "from wenet.utils.executor import Executor\n", + "from wenet.utils.scheduler import WarmupLR\n", + "\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2f6ea33a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser(description='training your network')\n", + "parser.add_argument('--config', default=\"examples/aishell/s0/conf/train_conformer.yaml\", help='config 
file')\n", + "parser.add_argument('--train_data', default=\"examples/aishell/s0/raw_wav/train/format.data\", help='train data file')\n", + "parser.add_argument('--cv_data', default=\"examples/aishell/s0/raw_wav/dev/format.data\", help='cv data file')\n", + "parser.add_argument('--gpu',\n", + " type=int,\n", + " default=-1,\n", + " help='gpu id for this local rank, -1 for cpu')\n", + "parser.add_argument('--model_dir' , help='save model dir')\n", + "parser.add_argument('--checkpoint', help='checkpoint model')\n", + "parser.add_argument('--tensorboard_dir',\n", + " default='tensorboard',\n", + " help='tensorboard log dir')\n", + "parser.add_argument('--ddp.rank',\n", + " dest='rank',\n", + " default=0,\n", + " type=int,\n", + " help='global rank for distributed training')\n", + "parser.add_argument('--ddp.world_size',\n", + " dest='world_size',\n", + " default=-1,\n", + " type=int,\n", + " help='''number of total processes/gpus for\n", + " distributed training''')\n", + "parser.add_argument('--ddp.dist_backend',\n", + " dest='dist_backend',\n", + " default='nccl',\n", + " choices=['nccl', 'gloo'],\n", + " help='distributed backend')\n", + "parser.add_argument('--ddp.init_method',\n", + " dest='init_method',\n", + " default=None,\n", + " help='ddp init method')\n", + "parser.add_argument('--num_workers',\n", + " default=0,\n", + " type=int,\n", + " help='num of subprocess workers for reading')\n", + "parser.add_argument('--pin_memory',\n", + " action='store_true',\n", + " default=False,\n", + " help='Use pinned memory buffers used for reading')\n", + "parser.add_argument('--cmvn', default=\"examples/aishell/s0/raw_wav/train/global_cmvn\", help='global cmvn file')\n", + "\n", + "args = parser.parse_args([])\n", + "print(vars(args))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f5d6af9b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)\n" + ] + } + ], + "source": [ + "# Set random seed\n", + "torch.manual_seed(777)\n", + "print(args)\n", + "with open(args.config, 'r') as fin:\n", + " configs = yaml.load(fin, Loader=yaml.FullLoader)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "264bd353", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7507 batches\n", + "896\n" + ] + } + ], + "source": [ + "raw_wav = configs['raw_wav']\n", + "\n", + "train_collate_func = CollateFunc(**configs['collate_conf'],\n", + " raw_wav=raw_wav)\n", + "\n", + "cv_collate_conf = copy.deepcopy(configs['collate_conf'])\n", + "# no augmenation on cv set\n", + "cv_collate_conf['spec_aug'] = False\n", + "cv_collate_conf['spec_sub'] = False\n", + "if raw_wav:\n", + " cv_collate_conf['feature_dither'] = 0.0\n", + " cv_collate_conf['speed_perturb'] = False\n", + " cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0\n", + "cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)\n", + "\n", + "dataset_conf = configs.get('dataset_conf', {})\n", + "train_dataset = AudioDataset(args.train_data,\n", + " **dataset_conf,\n", + " raw_wav=raw_wav)\n", + "cv_dataset = 
AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)\n", + "# 120098 data/train/wav.scp\n", + "print(len(train_dataset), 'batches')\n", + "# 14326 data/dev/wav.scp\n", + "print(len(cv_dataset))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "88863d3c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "896\n" + ] + } + ], + "source": [ + "train_sampler = None\n", + "cv_sampler = None\n", + "train_data_loader = DataLoader(train_dataset,\n", + " collate_fn=train_collate_func,\n", + " sampler=train_sampler,\n", + " #shuffle=(train_sampler is None),\n", + " shuffle=False,\n", + " pin_memory=args.pin_memory,\n", + " batch_size=1,\n", + " num_workers=args.num_workers)\n", + "cv_data_loader = DataLoader(cv_dataset,\n", + " collate_fn=cv_collate_func,\n", + " sampler=cv_sampler,\n", + " shuffle=False,\n", + " batch_size=1,\n", + " pin_memory=args.pin_memory,\n", + " num_workers=args.num_workers)\n", + "print(len(cv_data_loader))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "10d5acd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4233 vocab\n", + "80 feat dim\n" + ] + } + ], + "source": [ + "if raw_wav:\n", + " input_dim = configs['collate_conf']['feature_extraction_conf'][\n", + " 'mel_bins']\n", + "else:\n", + " input_dim = train_dataset.input_dim\n", + "vocab_size = train_dataset.output_dim\n", + "print(vocab_size, 'vocab')\n", + "print(input_dim , 'feat dim')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0380ef5a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "examples/aishell/s0/raw_wav/train/global_cmvn\n" + ] + } + ], + "source": [ + "# Save configs to model_dir/train.yaml for inference and export\n", + "configs['input_dim'] = input_dim\n", + "configs['output_dim'] = vocab_size\n", + "configs['cmvn_file'] = args.cmvn\n", + "configs['is_json_cmvn'] = raw_wav\n", + "print(args.cmvn)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "15ebf2bf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(80,)\n", + "(80,)\n", + "[ 9.87176362 9.93891555 10.23818678 10.85971412 11.68652649 12.2548801\n", + " 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205\n", + " 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338\n", + " 12.09285066 11.79395822 11.62259065 11.9263303 11.8154422 11.95122567\n", + " 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142\n", + " 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304\n", + " 12.45014401 12.4966879 12.48653775 12.3550783 12.39291732 12.2553737\n", + " 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342\n", + " 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142\n", + " 13.0588804 13.05737948 12.99921175 12.93402238 12.87429219 12.71652995\n", + " 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073\n", + " 12.51480191 12.5785164 12.64719411 12.73762568 12.80017069 12.86872766\n", + " 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279\n", + " 12.87936075 11.18310185]\n", + "[0.61219383 0.49700994 0.33439025 0.31503119 0.29640823 0.28411759\n", + " 0.26972922 0.25610475 0.24632936 0.24610228 0.24733299 0.24426536\n", + " 0.23751781 0.22987273 0.22659963 0.2268427 0.23059031 0.23420722\n", + " 0.23771761 0.2411352 
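For readers unfamiliar with the collate_fn hook passed to the DataLoader calls above: it receives a list of dataset items and is responsible for padding them into one batch. The sketch below is a generic, hypothetical pad-and-stack collate for (T, D) float32 feature arrays, shown only to illustrate the mechanism; it is not the wenet CollateFunc API:

import numpy as np
import torch
from torch.utils.data import DataLoader

def pad_collate(batch):
    """batch: list of (T_i, D) float32 numpy feature arrays (hypothetical item format).
    Returns a padded (B, T_max, D) tensor plus per-sample frame lengths."""
    lengths = torch.tensor([feat.shape[0] for feat in batch], dtype=torch.int32)
    t_max, dim = int(lengths.max()), batch[0].shape[1]
    padded = torch.zeros(len(batch), t_max, dim)
    for i, feat in enumerate(batch):
        padded[i, :feat.shape[0]] = torch.from_numpy(feat)
    return padded, lengths

# usage sketch: loader = DataLoader(some_dataset, batch_size=4, collate_fn=pad_collate)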
0.24404673 0.24557175 0.24724932 0.25055198\n", + " 0.25482755 0.2602407 0.26363878 0.26503898 0.2648467 0.26435072\n", + " 0.26353625 0.26364794 0.26411054 0.26339948 0.26212082 0.26146597\n", + " 0.26196556 0.26365859 0.26592959 0.26963884 0.27392766 0.27818809\n", + " 0.28313664 0.2863325 0.28713431 0.28649323 0.28636648 0.2867843\n", + " 0.28635904 0.28562022 0.28492711 0.28429201 0.28402977 0.28401045\n", + " 0.28560797 0.28728033 0.28969549 0.29351627 0.29826453 0.30572631\n", + " 0.31811682 0.32887739 0.33288219 0.33326245 0.33014147 0.32403202\n", + " 0.31903576 0.31316258 0.30741037 0.30370692 0.30204833 0.30049064\n", + " 0.29901079 0.29824511 0.29812308 0.29753329 0.29779342 0.30175296\n", + " 0.30955538 0.32904205]\n" + ] + } + ], + "source": [ + "import json\n", + "import math\n", + "import numpy as np\n", + "def _load_json_cmvn(json_cmvn_file):\n", + " \"\"\" Load the json format cmvn stats file and calculate cmvn\n", + "\n", + " Args:\n", + " json_cmvn_file: cmvn stats file in json format\n", + "\n", + " Returns:\n", + " a numpy array of [means, vars]\n", + " \"\"\"\n", + " with open(json_cmvn_file) as f:\n", + " cmvn_stats = json.load(f)\n", + "\n", + " means = cmvn_stats['mean_stat']\n", + " variance = cmvn_stats['var_stat']\n", + " count = cmvn_stats['frame_num']\n", + " for i in range(len(means)):\n", + " means[i] /= count\n", + " variance[i] = variance[i] / count - means[i] * means[i]\n", + " if variance[i] < 1.0e-20:\n", + " variance[i] = 1.0e-20\n", + " variance[i] = 1.0 / math.sqrt(variance[i])\n", + " cmvn = np.array([means, variance])\n", + " return cmvn\n", + "\n", + "\n", + "def _load_kaldi_cmvn(kaldi_cmvn_file):\n", + " \"\"\" Load the kaldi format cmvn stats file and calculate cmvn\n", + "\n", + " Args:\n", + " kaldi_cmvn_file: kaldi text style global cmvn file, which\n", + " is generated by:\n", + " compute-cmvn-stats --binary=false scp:feats.scp global_cmvn\n", + "\n", + " Returns:\n", + " a numpy array of [means, vars]\n", + " \"\"\"\n", + " means = []\n", + " variance = []\n", + " with open(kaldi_cmvn_file, 'r') as fid:\n", + " # kaldi binary file start with '\\0B'\n", + " if fid.read(2) == '\\0B':\n", + " logger.error('kaldi cmvn binary file is not supported, please '\n", + " 'recompute it by: compute-cmvn-stats --binary=false '\n", + " ' scp:feats.scp global_cmvn')\n", + " sys.exit(1)\n", + " fid.seek(0)\n", + " arr = fid.read().split()\n", + " assert (arr[0] == '[')\n", + " assert (arr[-2] == '0')\n", + " assert (arr[-1] == ']')\n", + " feat_dim = int((len(arr) - 2 - 2) / 2)\n", + " for i in range(1, feat_dim + 1):\n", + " means.append(float(arr[i]))\n", + " count = float(arr[feat_dim + 1])\n", + " for i in range(feat_dim + 2, 2 * feat_dim + 2):\n", + " variance.append(float(arr[i]))\n", + "\n", + " for i in range(len(means)):\n", + " means[i] /= count\n", + " variance[i] = variance[i] / count - means[i] * means[i]\n", + " if variance[i] < 1.0e-20:\n", + " variance[i] = 1.0e-20\n", + " variance[i] = 1.0 / math.sqrt(variance[i])\n", + " cmvn = np.array([means, variance])\n", + " return cmvn\n", + "\n", + "\n", + "def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):\n", + " npzfile = np.load(npz_cmvn_file)\n", + " means = npzfile[\"mean\"] #(1, D)\n", + " std = npzfile[\"std\"] #(1, D)\n", + " std = np.clip(std, eps, None)\n", + " variance = 1.0 / std\n", + " cmvn = np.array([means, variance])\n", + " return cmvn\n", + "\n", + "\n", + "def load_cmvn(cmvn_file: str, filetype: str):\n", + " \"\"\"load cmvn from file.\n", + "\n", + " Args:\n", + " cmvn_file (str): 
cmvn path.\n", + " filetype (str): file type, optional[npz, json, kaldi].\n", + "\n", + " Raises:\n", + " ValueError: file type not support.\n", + "\n", + " Returns:\n", + " Tuple[np.ndarray, np.ndarray]: mean, istd\n", + " \"\"\"\n", + " assert filetype in ['npz', 'json', 'kaldi'], filetype\n", + " filetype = filetype.lower()\n", + " if filetype == \"json\":\n", + " cmvn = _load_json_cmvn(cmvn_file)\n", + " elif filetype == \"kaldi\":\n", + " cmvn = _load_kaldi_cmvn(cmvn_file)\n", + " elif filetype == \"npz\":\n", + " cmvn = _load_npz_cmvn(cmvn_file)\n", + " else:\n", + " raise ValueError(f\"cmvn file type no support: {filetype}\")\n", + " return cmvn[0], cmvn[1]\n", + "\n", + "mean, istd = load_cmvn(args.cmvn, 'json')\n", + "print(mean.shape)\n", + "print(istd.shape)\n", + "print(mean)\n", + "print(istd)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3cfa5e23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ASRModel(\n", + " (encoder): ConformerEncoder(\n", + " (global_cmvn): GlobalCMVN()\n", + " (embed): Conv2dSubsampling4(\n", + " (conv): Sequential(\n", + " (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))\n", + " (1): ReLU()\n", + " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))\n", + " (3): ReLU()\n", + " )\n", + " (out): Sequential(\n", + " (0): Linear(in_features=4864, out_features=256, bias=True)\n", + " )\n", + " (pos_enc): RelPositionalEncoding(\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (encoders): ModuleList(\n", + " (0): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): 
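Returning to the CMVN statistics loaded a few cells above: the json loader converts accumulated sums into a per-dimension mean and inverse standard deviation, which the GlobalCMVN layer in the model presumably applies as (x - mean) * istd. A compact numpy restatement of that arithmetic, with synthetic stats and the same 1e-20 variance floor as the notebook (stats_to_cmvn is an illustrative name, not part of wenet):

import numpy as np

def stats_to_cmvn(mean_stat, var_stat, frame_num, floor=1.0e-20):
    """Convert accumulated sums to (mean, istd), mirroring the json loader above."""
    mean = np.asarray(mean_stat, dtype=np.float64) / frame_num
    var = np.asarray(var_stat, dtype=np.float64) / frame_num - mean * mean
    var = np.maximum(var, floor)          # floor tiny variances before inverting
    istd = 1.0 / np.sqrt(var)
    return mean, istd

mean, istd = stats_to_cmvn([10.0, 20.0], [60.0, 500.0], frame_num=5)
normalized = (np.array([[2.0, 4.0]]) - mean) * istd   # (x - mean) * istd, per frame
print(mean, istd, normalized)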
Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (1): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (2): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " 
(pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (3): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (4): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, 
out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (5): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (6): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " 
(dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (7): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, 
inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (8): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (9): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): 
Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (10): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (11): ConformerEncoderLayer(\n", + " (self_attn): RelPositionMultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (linear_pos): Linear(in_features=256, out_features=256, bias=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (feed_forward_macaron): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, 
bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (conv_module): ConvolutionModule(\n", + " (pointwise_conv1): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): Conv1d(256, 256, kernel_size=(15,), stride=(1,), padding=(7,), groups=256)\n", + " (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pointwise_conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,))\n", + " (activation): Swish()\n", + " )\n", + " (norm_ff): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_mha): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_ff_macaron): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_conv): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm_final): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (decoder): TransformerDecoder(\n", + " (embed): Sequential(\n", + " (0): Embedding(4233, 256)\n", + " (1): PositionalEncoding(\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (output_layer): Linear(in_features=256, out_features=4233, bias=True)\n", + " (decoders): ModuleList(\n", + " (0): DecoderLayer(\n", + " (self_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (src_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): ReLU()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", + " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (1): DecoderLayer(\n", + " (self_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (src_attn): MultiHeadedAttention(\n", + " (linear_q): 
Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): ReLU()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", + " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (2): DecoderLayer(\n", + " (self_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (src_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): ReLU()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", + " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (3): DecoderLayer(\n", + " (self_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (src_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): ReLU()\n", + " (dropout): 
Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", + " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (4): DecoderLayer(\n", + " (self_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (src_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): ReLU()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", + " (concat_linear2): Linear(in_features=512, out_features=256, bias=True)\n", + " )\n", + " (5): DecoderLayer(\n", + " (self_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (src_attn): MultiHeadedAttention(\n", + " (linear_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_k): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_v): Linear(in_features=256, out_features=256, bias=True)\n", + " (linear_out): Linear(in_features=256, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (feed_forward): PositionwiseFeedForward(\n", + " (w_1): Linear(in_features=256, out_features=2048, bias=True)\n", + " (activation): ReLU()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (w_2): Linear(in_features=2048, out_features=256, bias=True)\n", + " )\n", + " (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (norm3): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (concat_linear1): Linear(in_features=512, out_features=256, bias=True)\n", + " (concat_linear2): Linear(in_features=512, 
out_features=256, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (ctc): CTC(\n", + " (ctc_lo): Linear(in_features=256, out_features=4233, bias=True)\n", + " (ctc_loss): CTCLoss()\n", + " )\n", + " (criterion_att): LabelSmoothingLoss(\n", + " (criterion): KLDivLoss()\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "# Init asr model from configs\n", + "model = init_asr_model(configs)\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3c780af5", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def summary(layer, print_func=print):\n", + " num_params = num_elements = 0\n", + " for name, param in layer.state_dict().items():\n", + " if print_func:\n", + " print_func(\n", + " \"{} | {} | {}\".format(name, param.shape, np.prod(param.shape)))\n", + " num_elements += np.prod(param.shape)\n", + " num_params += 1\n", + " if print_func:\n", + " print_func(\n", + " f\"Total parameters: {num_params}, {num_elements} elements.\"\n", + " )\n", + " \n", + "def print_params(model, print_func=print):\n", + " if print_func is None:\n", + " return\n", + " total = 0.0\n", + " num_params = 0.0\n", + " for n, p in model.named_parameters():\n", + " msg = f\"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}\"\n", + " total += np.prod(p.shape)\n", + " num_params += 1\n", + " if print_func:\n", + " print_func(msg)\n", + " if print_func:\n", + " print_func(f\"Total parameters: {num_params}, {total} elements.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "e159a200", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "encoder.global_cmvn.mean | torch.Size([80]) | 80\n", + "encoder.global_cmvn.istd | torch.Size([80]) | 80\n", + "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304\n", + "encoder.embed.conv.0.bias | torch.Size([256]) | 256\n", + "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824\n", + "encoder.embed.conv.2.bias | torch.Size([256]) | 256\n", + "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184\n", + "encoder.embed.out.0.bias | torch.Size([256]) | 256\n", + "encoder.after_norm.weight | torch.Size([256]) | 256\n", + "encoder.after_norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 
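The print_params helper defined above produces the long per-parameter listing that follows; its core is just summing np.prod over parameter shapes. A minimal standalone version for any torch module is sketched here (count_parameters is an illustrative name, not part of the notebook):

import numpy as np
import torch

def count_parameters(module: torch.nn.Module) -> int:
    """Total number of elements across all parameters, as print_params tallies."""
    return int(sum(np.prod(p.shape) for p in module.parameters()))

# Tiny sanity check on a known layer: 80*256 weights + 256 biases.
layer = torch.nn.Linear(80, 256)
assert count_parameters(layer) == 80 * 256 + 256
print(count_parameters(layer))  # 20736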
524288\n", + "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.0.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.0.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + 
"encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.1.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.1.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + 
"encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.2.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.2.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + 
"encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.3.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.3.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.4.conv_module.norm.bias | 
torch.Size([256]) | 256\n", + "encoder.encoders.4.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.4.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.4.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.5.conv_module.norm.running_var | torch.Size([256]) | 256\n", + 
"encoder.encoders.5.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.6.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.6.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + 
"encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.7.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.7.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256\n", 
+ "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.8.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.8.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + 
"encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.9.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.9.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256\n", + 
"encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.10.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.10.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256\n", + 
"encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256\n", + "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256\n", + "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536\n", + "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048\n", + "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072\n", + "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512\n", + "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840\n", + "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256\n", + "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.conv_module.norm.running_mean | torch.Size([256]) | 256\n", + "encoder.encoders.11.conv_module.norm.running_var | torch.Size([256]) | 256\n", + "encoder.encoders.11.conv_module.norm.num_batches_tracked | torch.Size([]) | 1.0\n", + "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536\n", + "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256\n", + "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256\n", + "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072\n", + "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256\n", + 
"decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648\n", + "decoder.after_norm.weight | torch.Size([256]) | 256\n", + "decoder.after_norm.bias | torch.Size([256]) | 256\n", + "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648\n", + "decoder.output_layer.bias | torch.Size([4233]) | 4233\n", + "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256\n", + "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256\n", + "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256\n", + "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.src_attn.linear_k.bias | 
torch.Size([256]) | 256\n", + "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.norm1.weight | torch.Size([256]) | 256\n", + "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256\n", + "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256\n", + "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256\n", + "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256\n", + "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256\n", + "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.2.concat_linear2.weight | 
torch.Size([256, 512]) | 131072\n", + "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256\n", + "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256\n", + "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256\n", + "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256\n", + 
"decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256\n", + "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256\n", + "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256\n", + "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536\n", + "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288\n", + "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048\n", + "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288\n", + "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256\n", + "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256\n", + "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256\n", + "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256\n", + "decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072\n", + "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256\n", + "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648\n", + "ctc.ctc_lo.bias | 
torch.Size([4233]) | 4233\n", + "Total parameters: 701, 49355454.0 elements.\n" + ] + } + ], + "source": [ + "summary(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8494c6ab", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0648a969", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True\n", + "encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True\n", + "encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True\n", + "encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True\n", + "encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True\n", + "encoder.embed.out.0.bias | torch.Size([256]) | 256 | True\n", + "encoder.after_norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.after_norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.0.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.0.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.0.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.0.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.0.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.0.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.0.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_ff.bias | 
torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.0.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.0.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.1.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.1.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.1.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.1.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.1.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.1.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.1.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.1.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_mha.bias | torch.Size([256]) | 256 | True\n", + 
"encoder.encoders.1.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.1.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.1.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.2.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.2.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.2.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.2.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.2.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.2.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.2.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.2.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + 
"encoder.encoders.2.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.2.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.2.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.3.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.3.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.3.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.3.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.3.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.3.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.3.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.3.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_final.weight | 
torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.3.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.3.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.4.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.4.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.4.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.4.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.4.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.4.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.4.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.4.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.4.concat_linear.weight | torch.Size([256, 512]) | 131072 | 
True\n", + "encoder.encoders.4.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.5.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.5.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.5.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.5.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.5.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.5.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.5.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.5.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.5.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.5.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + 
"encoder.encoders.6.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.6.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.6.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.6.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.6.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.6.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.6.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.6.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.6.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.6.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.6.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.6.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.6.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.6.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.6.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.6.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.7.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.7.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + 
"encoder.encoders.7.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.7.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.7.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.7.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.7.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.7.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.7.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.7.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.7.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.7.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.7.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.7.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.7.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.7.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.7.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.7.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.7.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.8.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.8.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + 
"encoder.encoders.8.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.8.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.8.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.8.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.8.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.8.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.8.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.8.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.8.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.8.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.8.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.8.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.8.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.8.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.8.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.9.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.9.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + 
"encoder.encoders.9.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.9.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.9.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.9.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.9.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.9.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.9.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.9.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.9.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.9.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.9.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.9.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.9.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.9.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.9.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.10.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.10.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.10.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.10.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.10.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + 
"encoder.encoders.10.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.self_attn.linear_pos.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.10.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.10.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.10.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.10.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.10.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.10.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.10.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.10.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.10.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.10.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.10.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.10.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.10.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.11.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True\n", + "encoder.encoders.11.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.11.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.self_attn.linear_pos.weight | 
torch.Size([256, 256]) | 65536 | True\n", + "encoder.encoders.11.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.11.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.11.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.11.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.feed_forward_macaron.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "encoder.encoders.11.feed_forward_macaron.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "encoder.encoders.11.feed_forward_macaron.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "encoder.encoders.11.feed_forward_macaron.w_2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.conv_module.pointwise_conv1.weight | torch.Size([512, 256, 1]) | 131072 | True\n", + "encoder.encoders.11.conv_module.pointwise_conv1.bias | torch.Size([512]) | 512 | True\n", + "encoder.encoders.11.conv_module.depthwise_conv.weight | torch.Size([256, 1, 15]) | 3840 | True\n", + "encoder.encoders.11.conv_module.depthwise_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.conv_module.norm.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.conv_module.norm.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.conv_module.pointwise_conv2.weight | torch.Size([256, 256, 1]) | 65536 | True\n", + "encoder.encoders.11.conv_module.pointwise_conv2.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_ff.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_ff.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_mha.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_mha.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_ff_macaron.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_ff_macaron.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_conv.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_conv.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_final.weight | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.norm_final.bias | torch.Size([256]) | 256 | True\n", + "encoder.encoders.11.concat_linear.weight | torch.Size([256, 512]) | 131072 | True\n", + "encoder.encoders.11.concat_linear.bias | torch.Size([256]) | 256 | True\n", + "decoder.embed.0.weight | torch.Size([4233, 256]) | 1083648 | True\n", + "decoder.after_norm.weight | torch.Size([256]) | 256 | True\n", + "decoder.after_norm.bias | torch.Size([256]) | 256 | True\n", + "decoder.output_layer.weight | torch.Size([4233, 256]) | 1083648 | True\n", + "decoder.output_layer.bias | torch.Size([4233]) | 4233 | True\n", + "decoder.decoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + 
"decoder.decoders.0.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.0.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "decoder.decoders.0.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "decoder.decoders.0.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "decoder.decoders.0.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.norm1.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.norm1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.norm2.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.norm2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.norm3.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.norm3.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.0.concat_linear1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.0.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.0.concat_linear2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.1.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "decoder.decoders.1.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "decoder.decoders.1.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "decoder.decoders.1.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + 
"decoder.decoders.1.norm1.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.norm1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.norm2.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.norm2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.norm3.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.norm3.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.1.concat_linear1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.1.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.1.concat_linear2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.2.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "decoder.decoders.2.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "decoder.decoders.2.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "decoder.decoders.2.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.norm1.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.norm1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.norm2.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.norm2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.norm3.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.norm3.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.2.concat_linear1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.2.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.2.concat_linear2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + 
"decoder.decoders.3.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.3.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "decoder.decoders.3.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "decoder.decoders.3.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "decoder.decoders.3.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.norm1.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.norm1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.norm2.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.norm2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.norm3.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.norm3.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.3.concat_linear1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.3.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.3.concat_linear2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.4.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | 
True\n", + "decoder.decoders.4.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "decoder.decoders.4.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "decoder.decoders.4.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "decoder.decoders.4.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.norm1.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.norm1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.norm2.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.norm2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.norm3.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.norm3.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.4.concat_linear1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.4.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.4.concat_linear2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.self_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.self_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.self_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.src_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_q.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.src_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_k.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.src_attn.linear_v.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_v.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.src_attn.linear_out.weight | torch.Size([256, 256]) | 65536 | True\n", + "decoder.decoders.5.src_attn.linear_out.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.feed_forward.w_1.weight | torch.Size([2048, 256]) | 524288 | True\n", + "decoder.decoders.5.feed_forward.w_1.bias | torch.Size([2048]) | 2048 | True\n", + "decoder.decoders.5.feed_forward.w_2.weight | torch.Size([256, 2048]) | 524288 | True\n", + "decoder.decoders.5.feed_forward.w_2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.norm1.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.norm1.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.norm2.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.norm2.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.norm3.weight | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.norm3.bias | torch.Size([256]) | 256 | True\n", + "decoder.decoders.5.concat_linear1.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.5.concat_linear1.bias | torch.Size([256]) | 256 | True\n", + 
"decoder.decoders.5.concat_linear2.weight | torch.Size([256, 512]) | 131072 | True\n", + "decoder.decoders.5.concat_linear2.bias | torch.Size([256]) | 256 | True\n", + "ctc.ctc_lo.weight | torch.Size([4233, 256]) | 1083648 | True\n", + "ctc.ctc_lo.bias | torch.Size([4233]) | 4233 | True\n", + "Total parameters: 663.0, 49349138.0 elements.\n" + ] + } + ], + "source": [ + "print_params(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5ad6de2a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']\n", + "torch.Size([16, 207, 80])\n", + "tensor([[[ 8.9946, 9.5383, 9.1916, ..., 10.5074, 9.5633, 8.2564],\n", + " [ 9.7988, 10.4052, 9.2651, ..., 10.2512, 9.5440, 8.8738],\n", + " [10.6891, 10.3955, 8.0535, ..., 9.9067, 10.0649, 8.0509],\n", + " ...,\n", + " [ 9.2180, 9.6507, 8.5053, ..., 9.6872, 8.7425, 7.9865],\n", + " [10.1291, 9.9352, 9.3798, ..., 9.5639, 9.8260, 8.9795],\n", + " [ 9.0955, 7.1338, 9.4680, ..., 9.4727, 9.0212, 7.4479]],\n", + "\n", + " [[11.4310, 10.6719, 6.0841, ..., 9.3827, 8.7297, 7.5316],\n", + " [ 9.7317, 7.8105, 7.5715, ..., 10.0430, 9.2436, 7.3541],\n", + " [10.6502, 10.6006, 8.4678, ..., 9.2814, 9.1869, 8.0703],\n", + " ...,\n", + " [ 9.0970, 9.2637, 8.0753, ..., 8.4318, 8.3705, 8.0029],\n", + " [10.4617, 10.1478, 6.7693, ..., 9.7794, 9.5775, 8.0807],\n", + " [ 7.7944, 5.6211, 7.9751, ..., 9.9972, 9.8497, 8.0313]],\n", + "\n", + " [[ 7.3456, 7.8964, 7.5796, ..., 11.6310, 10.4513, 9.1236],\n", + " [ 8.6287, 8.4631, 7.4992, ..., 12.4160, 10.9757, 8.9426],\n", + " [ 9.8314, 10.2813, 8.9724, ..., 12.1387, 10.4017, 9.0055],\n", + " ...,\n", + " [ 7.0896, 7.4055, 6.8143, ..., 9.3252, 9.2732, 8.3534],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + "\n", + " ...,\n", + "\n", + " [[10.9332, 10.4644, 7.7203, ..., 10.3488, 9.3023, 7.1553],\n", + " [10.4499, 9.9070, 9.0293, ..., 9.9525, 9.4141, 7.5593],\n", + " [10.4877, 9.8126, 9.8952, ..., 9.5866, 9.3413, 7.7849],\n", + " ...,\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + "\n", + " [[ 9.9444, 9.5859, 8.2203, ..., 11.5886, 11.0450, 8.8171],\n", + " [ 7.6784, 8.3224, 7.5330, ..., 11.0551, 10.5357, 9.2746],\n", + " [ 8.6262, 9.6759, 9.8410, ..., 11.3788, 10.9221, 8.9914],\n", + " ...,\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],\n", + "\n", + " [[ 8.1079, 7.7590, 6.7103, ..., 12.6506, 11.4662, 11.0615],\n", + " [11.3803, 11.2220, 8.6589, ..., 12.8106, 12.2222, 11.6893],\n", + " [10.6777, 9.9206, 8.0461, ..., 13.5729, 12.5624, 11.1550],\n", + " ...,\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])\n", + "tensor([207, 207, 205, 205, 203, 203, 198, 197, 195, 188, 186, 186, 185, 180,\n", + " 166, 163], dtype=torch.int32)\n", + 
"tensor([[2995, 3116, 1209, 565, -1, -1],\n", + " [ 236, 1176, 331, 66, 3925, 4077],\n", + " [2693, 524, 234, 1145, 366, -1],\n", + " [3875, 4211, 3062, 700, -1, -1],\n", + " [ 272, 987, 1134, 494, 2959, -1],\n", + " [1936, 3715, 120, 2553, 2695, 2710],\n", + " [ 25, 1149, 3930, -1, -1, -1],\n", + " [1753, 1778, 1237, 482, 3925, 110],\n", + " [3703, 2, 565, 3827, -1, -1],\n", + " [1150, 2734, 10, 2478, 3490, -1],\n", + " [ 426, 811, 95, 489, 144, -1],\n", + " [2313, 2006, 489, 975, -1, -1],\n", + " [3702, 3414, 205, 1488, 2966, 1347],\n", + " [ 70, 1741, 702, 1666, -1, -1],\n", + " [ 703, 1778, 1030, 849, -1, -1],\n", + " [ 814, 1674, 115, 3827, -1, -1]], dtype=torch.int32)\n", + "tensor([4, 6, 5, 4, 5, 6, 3, 6, 4, 5, 5, 4, 6, 4, 4, 4], dtype=torch.int32)\n" + ] + } + ], + "source": [ + "for batch in cv_data_loader:\n", + " keys, feat, text, feat_len, text_len = batch\n", + " print(keys)\n", + " print(feat.shape)\n", + " print(feat)\n", + " print(feat_len)\n", + " print(text)\n", + " print(text_len)\n", + " np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "852a9c95", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CODE_OF_CONDUCT.md data.npz install.sh README.md\t tools\r\n", + "CONTRIBUTING.md docs LICENSE\t requirements.txt venv\r\n", + "CPPLINT.cfg\t examples Makefile\t runtime\t wenet\r\n" + ] + } + ], + "source": [ + "!ls\n", + "!cp data.npz /workspace/DeepSpeech-2.x/.notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "cde24c4e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(111.9988)\n", + "tensor(830.9634, grad_fn=)\n", + "tensor([False, False, False, False, False, True, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, True, False, False, False,\n", + " False, False, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, True, True, False, False, False, False, False,\n", + " True, True])\n", + "tensor(669.4633, grad_fn=)\n", + "tensor(142.4888, grad_fn=) tensor(41.8415, grad_fn=) tensor(377.3326, grad_fn=)\n" + ] + } + ], + "source": [ + "model.cpu().eval()\n", + "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", + " text, text_len)\n", + "print(total_loss, attention_loss, ctc_loss )" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "be5b2a2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cpu\n" + ] + } + ], + "source": [ + "print(total_loss.device)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5b791771", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(112., device='cuda:0')\n", + "tensor(830.9634, device='cuda:0', grad_fn=)\n", + 
"tensor([False, False, False, False, False, True, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, True, False, False, False,\n", + " False, False, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, True, True, False, False, False, False, False,\n", + " True, True], device='cuda:0')\n", + "tensor(669.4634, device='cuda:0', grad_fn=)\n", + "cuda:0\n", + "142.4888 41.84146 377.33258\n" + ] + } + ], + "source": [ + "model.cuda().eval()\n", + "feat=feat.cuda()\n", + "feat_len=feat_len.cuda()\n", + "text=text.cuda()\n", + "text_len=text_len.cuda()\n", + "\n", + "total_loss, attention_loss, ctc_loss = model(feat, feat_len,\n", + " text, text_len)\n", + "print(total_loss.device)\n", + "print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "1baef537", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([16, 51, 256])\n", + "torch.Size([16, 1, 51])\n", + "tensor([[-0.7019, 0.5625, 0.6880, ..., 1.1237, 0.7804, 1.1369],\n", + " [-0.7788, 0.3913, 0.7189, ..., 1.2519, 0.8862, 1.3173],\n", + " [-0.9591, 0.6346, 0.8767, ..., 0.9818, 0.7440, 1.2903],\n", + " ...,\n", + " [-1.0732, 0.6724, 0.9230, ..., 0.9075, 0.8177, 1.3240],\n", + " [-1.1654, 0.6820, 0.6939, ..., 1.2238, 0.8028, 1.4507],\n", + " [-1.2732, 0.7146, 0.7582, ..., 0.9415, 0.8775, 1.2623]],\n", + " device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "encoder_out, encoder_mask = model.encoder(feat, feat_len)\n", + "print(encoder_out.shape)\n", + "print(encoder_mask.shape)\n", + "print(encoder_out[0])\n", + "\n", + "np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',\n", + " mask=encoder_mask.cpu().detach().numpy(), \n", + " out=encoder_out.cpu().detach().numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e22c782", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "30b6b946", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 9.871763 9.938915 10.238187 10.8597145 11.686526 12.25488\n", + " 12.657681 12.86139 12.807339 12.566256 12.32007 12.138792\n", + " 12.313189 12.552552 12.612239 12.569745 12.389728 12.143833\n", + " 12.092851 11.793959 11.622591 11.926331 11.815442 11.951225\n", + " 11.831805 11.887888 11.790144 11.88072 11.900057 11.973481\n", + " 12.009822 12.008814 12.026197 12.104796 12.21555 12.343993\n", + " 12.450144 12.496688 12.486538 12.355079 12.392918 12.255374\n", + " 12.264963 12.253142 12.325458 12.4335985 12.548675 12.676334\n", + " 12.809207 12.929347 12.961151 12.968834 12.995931 13.047281\n", + " 13.058881 13.05738 12.999211 12.934022 12.874292 12.71653\n", + " 12.48942 12.274784 12.261631 12.286319 12.31956 12.422907\n", + " 12.514802 12.578516 12.647194 12.737626 12.800171 12.868728\n", + " 12.966668 
13.064786 13.159159 13.272843 13.310819 13.239043\n", + " 12.879361 11.183102 ] float32\n", + "encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)\n", + "encoder.encoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.0.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.0.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.0.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.0.conv_module.norm.running_mean -> encoder.encoders.0.conv_module.norm._mean\n", + "encoder.encoders.0.conv_module.norm.running_var -> encoder.encoders.0.conv_module.norm._variance\n", + "encoder.encoders.0.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.1.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.1.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.1.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.1.conv_module.norm.running_mean -> encoder.encoders.1.conv_module.norm._mean\n", + "encoder.encoders.1.conv_module.norm.running_var -> encoder.encoders.1.conv_module.norm._variance\n", + "encoder.encoders.1.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.2.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.2.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.2.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.2.conv_module.norm.running_mean -> encoder.encoders.2.conv_module.norm._mean\n", + "encoder.encoders.2.conv_module.norm.running_var -> encoder.encoders.2.conv_module.norm._variance\n", + "encoder.encoders.2.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.3.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + 
"encoder.encoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.3.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.3.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.3.conv_module.norm.running_mean -> encoder.encoders.3.conv_module.norm._mean\n", + "encoder.encoders.3.conv_module.norm.running_var -> encoder.encoders.3.conv_module.norm._variance\n", + "encoder.encoders.3.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.4.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.4.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.4.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.4.conv_module.norm.running_mean -> encoder.encoders.4.conv_module.norm._mean\n", + "encoder.encoders.4.conv_module.norm.running_var -> encoder.encoders.4.conv_module.norm._variance\n", + "encoder.encoders.4.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.5.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.5.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.5.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.5.conv_module.norm.running_mean -> encoder.encoders.5.conv_module.norm._mean\n", + "encoder.encoders.5.conv_module.norm.running_var -> encoder.encoders.5.conv_module.norm._variance\n", + "encoder.encoders.5.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.6.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.6.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.6.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.6.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.6.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.6.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.6.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.6.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.6.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.6.conv_module.norm.running_mean -> encoder.encoders.6.conv_module.norm._mean\n", + "encoder.encoders.6.conv_module.norm.running_var -> encoder.encoders.6.conv_module.norm._variance\n", + 
"encoder.encoders.6.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.7.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.7.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.7.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.7.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.7.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.7.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.7.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.7.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.7.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.7.conv_module.norm.running_mean -> encoder.encoders.7.conv_module.norm._mean\n", + "encoder.encoders.7.conv_module.norm.running_var -> encoder.encoders.7.conv_module.norm._variance\n", + "encoder.encoders.7.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.8.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.8.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.8.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.8.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.8.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.8.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.8.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.8.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.8.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.8.conv_module.norm.running_mean -> encoder.encoders.8.conv_module.norm._mean\n", + "encoder.encoders.8.conv_module.norm.running_var -> encoder.encoders.8.conv_module.norm._variance\n", + "encoder.encoders.8.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.9.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.9.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.9.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.9.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.9.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.9.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.9.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.9.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.9.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.9.conv_module.norm.running_mean -> encoder.encoders.9.conv_module.norm._mean\n", + "encoder.encoders.9.conv_module.norm.running_var -> encoder.encoders.9.conv_module.norm._variance\n", + "encoder.encoders.9.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.10.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.10.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.10.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.10.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.10.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.10.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + 
"encoder.encoders.10.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.10.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.10.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.10.conv_module.norm.running_mean -> encoder.encoders.10.conv_module.norm._mean\n", + "encoder.encoders.10.conv_module.norm.running_var -> encoder.encoders.10.conv_module.norm._variance\n", + "encoder.encoders.10.concat_linear.weight: (256, 512) -> (512, 256)\n", + "encoder.encoders.11.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.11.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.11.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.11.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.11.self_attn.linear_pos.weight: (256, 256) -> (256, 256)\n", + "encoder.encoders.11.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.11.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.11.feed_forward_macaron.w_1.weight: (2048, 256) -> (256, 2048)\n", + "encoder.encoders.11.feed_forward_macaron.w_2.weight: (256, 2048) -> (2048, 256)\n", + "encoder.encoders.11.conv_module.norm.running_mean -> encoder.encoders.11.conv_module.norm._mean\n", + "encoder.encoders.11.conv_module.norm.running_var -> encoder.encoders.11.conv_module.norm._variance\n", + "encoder.encoders.11.concat_linear.weight: (256, 512) -> (512, 256)\n", + "decoder.output_layer.weight: (4233, 256) -> (256, 4233)\n", + "decoder.decoders.0.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.0.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "decoder.decoders.0.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "decoder.decoders.0.concat_linear1.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.0.concat_linear2.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.1.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.1.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "decoder.decoders.1.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "decoder.decoders.1.concat_linear1.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.1.concat_linear2.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.2.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + 
"decoder.decoders.2.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.2.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "decoder.decoders.2.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "decoder.decoders.2.concat_linear1.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.2.concat_linear2.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.3.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.3.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "decoder.decoders.3.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "decoder.decoders.3.concat_linear1.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.3.concat_linear2.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.4.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.4.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "decoder.decoders.4.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "decoder.decoders.4.concat_linear1.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.4.concat_linear2.weight: (256, 512) -> (512, 256)\n", + "decoder.decoders.5.self_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.self_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.self_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.self_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.src_attn.linear_q.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.src_attn.linear_k.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.src_attn.linear_v.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.src_attn.linear_out.weight: (256, 256) -> (256, 256)\n", + "decoder.decoders.5.feed_forward.w_1.weight: (2048, 256) -> (256, 2048)\n", + "decoder.decoders.5.feed_forward.w_2.weight: (256, 2048) -> (2048, 256)\n", + "decoder.decoders.5.concat_linear1.weight: (256, 512) -> (512, 
256)\n", + "decoder.decoders.5.concat_linear2.weight: (256, 512) -> (512, 256)\n", + "ctc.ctc_lo.weight: (4233, 256) -> (256, 4233)\n" + ] + } + ], + "source": [ + "# dump torch model to paddle\n", + "import numpy as np\n", + "state_dict = model.state_dict()\n", + "paddle_state_dict = {}\n", + "\n", + "for n, p in state_dict.items():\n", + " name_change=True\n", + "\n", + " if 'norm.running_mean' in n:\n", + " new_n = n.replace('norm.running_', 'norm._')\n", + " elif 'norm.running_var' in n:\n", + " new_n = n.replace('norm.running_var', 'norm._variance')\n", + " else:\n", + " name_change=False\n", + " new_n = n\n", + " \n", + " if name_change:\n", + " print(f\"{n} -> {new_n}\")\n", + " \n", + " p = p.cpu().detach().numpy()\n", + " if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:\n", + " new_p = p.T\n", + " print(f\"{n}: {p.shape} -> {new_p.shape}\")\n", + " else:\n", + " new_p = p\n", + " \n", + " if 'global_cmvn.mean' in n:\n", + " print(p, p.dtype)\n", + " \n", + " paddle_state_dict[new_n] = new_p\n", + " \n", + "np.savez('/workspace/DeepSpeech-2.x/.notebook/model',\n", + " state=paddle_state_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7307dc5b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d99b29bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(377.3326, device='cuda:0', grad_fn=)\n", + "None\n", + "[[ 3.16902351e+00 -1.51765049e-02 4.91097234e-02 ... -2.47973716e-03\n", + " -5.93366381e-03 -7.26613170e-03]\n", + " [-1.74185038e+00 7.75875803e-03 -4.49435972e-02 ... 9.92415240e-04\n", + " 2.46338220e-03 2.31891591e-03]\n", + " [-2.33343077e+00 1.30476682e-02 -2.66557615e-02 ... 2.27533933e-03\n", + " 5.76929189e-03 7.48792710e-03]\n", + " ...\n", + " [-4.30356789e+00 2.46056803e-02 -9.00955945e-02 ... 4.43160534e-03\n", + " 1.16123557e-02 1.44716976e-02]\n", + " [-3.36919212e+00 1.73155665e-02 -6.36875406e-02 ... 3.28367390e-03\n", + " 8.58021621e-03 1.07796099e-02]\n", + " [-6.62039661e+00 3.49958315e-02 -1.23963736e-01 ... 6.36674836e-03\n", + " 1.60815325e-02 2.03892551e-02]]\n", + "[-4.3777566e+00 2.3245990e-02 -9.3339972e-02 ... 4.2569702e-03\n", + " 1.0920014e-02 1.3787906e-02]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":6: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. 
See github.com/pytorch/pytorch/pull/30531 for more informations.\n", + " print(loss_ctc.grad)\n" + ] + } + ], + "source": [ + "encoder_out_lens = encoder_mask.squeeze(1).sum(1)\n", + "loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)\n", + "print(loss_ctc)\n", + "dir(loss_ctc)\n", + "loss_ctc.backward()\n", + "print(loss_ctc.grad)\n", + "#print(model.ctc.ctc_lo.weight.grad)\n", + "print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())\n", + "print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "49b05d6d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(112., device='cuda:0')\n", + "tensor(830.9634, device='cuda:0', grad_fn=)\n", + "tensor([False, False, False, False, False, True, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, True, False, False, False,\n", + " False, False, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, True, True, False, False, False, False, False,\n", + " True, True], device='cuda:0')\n", + "tensor(669.4634, device='cuda:0', grad_fn=)\n", + "tensor(41.8415, device='cuda:0', grad_fn=) 0.0\n" + ] + } + ], + "source": [ + "loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,\n", + " text, text_len)\n", + "print(loss_att, acc_att)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "413b413f", + "metadata": {}, + "outputs": [], + "source": [ + "def pad_list(xs, pad_value: int):\n", + " n_batch = len(xs)\n", + " max_len = max([x.size(0) for x in xs])\n", + " pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)\n", + " pad = pad.fill_(pad_value)\n", + " for i in range(n_batch):\n", + " pad[i, :xs[i].size(0)] = xs[i]\n", + "\n", + " return pad\n", + "\n", + "def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,\n", + " ignore_id: int):\n", + "\n", + " _sos = torch.tensor([sos],\n", + " dtype=torch.long,\n", + " requires_grad=False,\n", + " device=ys_pad.device)\n", + " _eos = torch.tensor([eos],\n", + " dtype=torch.long,\n", + " requires_grad=False,\n", + " device=ys_pad.device)\n", + " ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys\n", + " ys_in = [torch.cat([_sos, y], dim=0) for y in ys]\n", + " ys_out = [torch.cat([y, _eos], dim=0) for y in ys]\n", + " return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "ff0c2400", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[4232, 2995, 3116, 1209, 565, 4232, 4232],\n", + " [4232, 236, 1176, 331, 66, 3925, 4077],\n", + " [4232, 2693, 524, 234, 1145, 366, 4232],\n", + " [4232, 3875, 4211, 3062, 700, 4232, 4232],\n", + " [4232, 272, 987, 1134, 494, 2959, 4232],\n", + " [4232, 1936, 3715, 120, 2553, 2695, 2710],\n", + " [4232, 25, 1149, 3930, 4232, 4232, 4232],\n", + " [4232, 1753, 1778, 1237, 482, 3925, 
110],\n", + " [4232, 3703, 2, 565, 3827, 4232, 4232],\n", + " [4232, 1150, 2734, 10, 2478, 3490, 4232],\n", + " [4232, 426, 811, 95, 489, 144, 4232],\n", + " [4232, 2313, 2006, 489, 975, 4232, 4232],\n", + " [4232, 3702, 3414, 205, 1488, 2966, 1347],\n", + " [4232, 70, 1741, 702, 1666, 4232, 4232],\n", + " [4232, 703, 1778, 1030, 849, 4232, 4232],\n", + " [4232, 814, 1674, 115, 3827, 4232, 4232]], device='cuda:0')\n", + "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", + " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", + " [2693, 524, 234, 1145, 366, 4232, -1],\n", + " [3875, 4211, 3062, 700, 4232, -1, -1],\n", + " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", + " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", + " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", + " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", + " [3703, 2, 565, 3827, 4232, -1, -1],\n", + " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", + " [ 426, 811, 95, 489, 144, 4232, -1],\n", + " [2313, 2006, 489, 975, 4232, -1, -1],\n", + " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", + " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", + " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", + " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n" + ] + } + ], + "source": [ + "ys_pad = text\n", + "ys_pad_lens = text_len\n", + "ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,\n", + " model.ignore_id)\n", + "ys_in_lens = ys_pad_lens + 1\n", + "print(ys_in_pad)\n", + "print(ys_out_pad)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "3e84da38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([16, 7, 4233])\n", + "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", + " 1.5035e-02, 4.0337e-01],\n", + " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", + " -1.4353e-01, -1.0024e+00],\n", + " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", + " -1.1672e+00, -2.6849e-01],\n", + " ...,\n", + " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", + " 4.5470e-02, -3.7139e-01],\n", + " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", + " -7.9347e-04, 4.2538e-01],\n", + " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", + " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,\n", + " ys_in_lens)\n", + "print(decoder_out.shape)\n", + "print(decoder_out[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aac441ea", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "5ddbca73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.float32\n", + "torch.int64\n", + "tensor(112., device='cuda:0')\n", + "tensor(830.9634, device='cuda:0', grad_fn=)\n", + "tensor([False, False, False, False, False, True, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, True, False, False, False,\n", + " False, 
False, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, True, True, False, False, False, False, False,\n", + " True, True], device='cuda:0')\n", + "tensor(669.4634, device='cuda:0', grad_fn=)\n", + "tensor(41.8415, device='cuda:0', grad_fn=)\n", + "tensor([[2995, 3116, 1209, 565, 4232, -1, -1],\n", + " [ 236, 1176, 331, 66, 3925, 4077, 4232],\n", + " [2693, 524, 234, 1145, 366, 4232, -1],\n", + " [3875, 4211, 3062, 700, 4232, -1, -1],\n", + " [ 272, 987, 1134, 494, 2959, 4232, -1],\n", + " [1936, 3715, 120, 2553, 2695, 2710, 4232],\n", + " [ 25, 1149, 3930, 4232, -1, -1, -1],\n", + " [1753, 1778, 1237, 482, 3925, 110, 4232],\n", + " [3703, 2, 565, 3827, 4232, -1, -1],\n", + " [1150, 2734, 10, 2478, 3490, 4232, -1],\n", + " [ 426, 811, 95, 489, 144, 4232, -1],\n", + " [2313, 2006, 489, 975, 4232, -1, -1],\n", + " [3702, 3414, 205, 1488, 2966, 1347, 4232],\n", + " [ 70, 1741, 702, 1666, 4232, -1, -1],\n", + " [ 703, 1778, 1030, 849, 4232, -1, -1],\n", + " [ 814, 1674, 115, 3827, 4232, -1, -1]], device='cuda:0')\n", + "tensor([[-3.7639e-01, -8.2272e-01, 7.4276e-01, ..., 3.4201e-01,\n", + " 1.5035e-02, 4.0337e-01],\n", + " [-8.7386e-01, -3.1389e-01, 4.1988e-01, ..., 3.7724e-01,\n", + " -1.4353e-01, -1.0024e+00],\n", + " [-4.3505e-01, 3.4505e-02, -2.8710e-01, ..., 7.7274e-02,\n", + " -1.1672e+00, -2.6849e-01],\n", + " ...,\n", + " [ 4.2471e-01, 5.8886e-01, 2.0204e-02, ..., 3.7405e-01,\n", + " 4.5470e-02, -3.7139e-01],\n", + " [-3.7978e-01, -8.1084e-01, 7.5725e-01, ..., 2.6039e-01,\n", + " -7.9347e-04, 4.2538e-01],\n", + " [-3.8280e-01, -8.1207e-01, 7.4943e-01, ..., 2.6173e-01,\n", + " -1.0499e-03, 4.2679e-01]], device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "print(decoder_out.dtype)\n", + "print(ys_out_pad.dtype)\n", + "loss_att = model.criterion_att(decoder_out, ys_out_pad)\n", + "print(loss_att)\n", + "print(ys_out_pad)\n", + "print(decoder_out[0])\n", + "np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',\n", + " decoder_out=decoder_out.cpu().detach().numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78f98c0b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "8d968cd3", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "\n", + "\n", + "class LabelSmoothingLoss(nn.Module):\n", + " def __init__(self,\n", + " size: int,\n", + " padding_idx: int,\n", + " smoothing: float,\n", + " normalize_length: bool = False):\n", + " \"\"\"Construct an LabelSmoothingLoss object.\"\"\"\n", + " super(LabelSmoothingLoss, self).__init__()\n", + " self.criterion = nn.KLDivLoss(reduction=\"none\")\n", + " self.padding_idx = padding_idx\n", + " self.confidence = 1.0 - smoothing\n", + " self.smoothing = smoothing\n", + " self.size = size\n", + " self.normalize_length = normalize_length\n", + "\n", + " def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"Compute loss between x and target.\n", + "\n", + " The model outputs and data labels tensors are flatten to\n", + " (batch*seqlen, class) shape and a mask is applied to the\n", + " padding part which should not be calculated for loss.\n", + "\n", + " Args:\n", + " x (torch.Tensor): prediction (batch, seqlen, class)\n", + " target (torch.Tensor):\n", + " target signal masked with self.padding_id (batch, seqlen)\n", + " Returns:\n", + " loss (torch.Tensor) : The KL 
loss, scalar float value\n", + " \"\"\"\n", + " assert x.size(2) == self.size\n", + " batch_size = x.size(0)\n", + " x = x.view(-1, self.size)\n", + " target = target.view(-1)\n", + " # use zeros_like instead of torch.no_grad() for true_dist,\n", + " # since no_grad() can not be exported by JIT\n", + " true_dist = torch.zeros_like(x)\n", + " true_dist.fill_(self.smoothing / (self.size - 1))\n", + " ignore = target == self.padding_idx # (B,)\n", + " print(self.smoothing / (self.size - 1))\n", + " print(true_dist)\n", + " total = len(target) - ignore.sum().item()\n", + " target = target.masked_fill(ignore, 0) # avoid -1 index\n", + " true_dist.scatter_(1, target.unsqueeze(1), self.confidence)\n", + " print(true_dist.dtype)\n", + " print(true_dist.square().sum())\n", + " kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)\n", + " print(kl.sum())\n", + " denom = total if self.normalize_length else batch_size\n", + " print(ignore)\n", + " numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()\n", + " print(numer)\n", + " return numer /denom" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "3df340ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.3629489603024576e-05\n", + "tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", + " 2.3629e-05],\n", + " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", + " 2.3629e-05],\n", + " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", + " 2.3629e-05],\n", + " ...,\n", + " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", + " 2.3629e-05],\n", + " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", + " 2.3629e-05],\n", + " [2.3629e-05, 2.3629e-05, 2.3629e-05, ..., 2.3629e-05, 2.3629e-05,\n", + " 2.3629e-05]], device='cuda:0')\n", + "torch.float32\n", + "tensor(90.7203, device='cuda:0')\n", + "tensor(830.9634, device='cuda:0', grad_fn=)\n", + "tensor([False, False, False, False, False, True, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, False, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, True, False, False, False,\n", + " False, False, True, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, False, False,\n", + " False, False, False, True, True, False, False, False, False, False,\n", + " True, True], device='cuda:0')\n", + "tensor(669.4634, device='cuda:0', grad_fn=)\n", + "tensor(41.8415, device='cuda:0', grad_fn=)\n", + "torch.int64\n" + ] + } + ], + "source": [ + "criteron = LabelSmoothingLoss(4233, -1, 0.1, False)\n", + "loss_att = criteron(decoder_out, ys_out_pad)\n", + "print(loss_att)\n", + "print(ys_out_pad.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "badc410d", + "metadata": {}, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Trying to backward through the graph a second time, but the saved intermediate results have already been freed. 
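The intermediate values printed by this re-implementation can be checked by hand: the off-target probability is `smoothing / (size - 1) = 0.1 / 4232 ≈ 2.3629e-05`, and because `normalize_length=False` the masked KL sum is divided by the batch size, `669.4634 / 16 ≈ 41.8415`, reproducing the attention loss computed earlier:

```python
# Reproduce the label-smoothing constants printed above.
size = 4233                # decoder output dimension (vocabulary size)
smoothing = 0.1
batch_size = 16
masked_kl_sum = 669.4634   # kl.masked_fill(ignore.unsqueeze(1), 0).sum() from the output

print(smoothing / (size - 1))      # 2.3629489603024576e-05
print(masked_kl_sum / batch_size)  # ~41.8415, the attention loss
```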
Specify retain_graph=True when calling backward the first time.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_att\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdecoder_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/workspace/wenet/venv/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m allow_unreachable=True) # allow_unreachable flag\n", + "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time." 
+ ] + } + ], + "source": [ + "loss_att.backward()\n", + "print(loss_att.grad)\n", + "print(decoder_out.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "219eb41f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 0.0024, 0.0019, -0.1098, ..., 0.0028, 0.0020, -1.7978],\n", + " device='cuda:0')\n", + "tensor([[ 6.5052e-04, 6.4419e-05, -6.1955e-06, ..., 9.8220e-04,\n", + " -2.5918e-05, 3.3754e-04],\n", + " [ 3.9305e-04, 4.5799e-04, 1.4362e-04, ..., 4.6800e-04,\n", + " 1.6911e-04, 2.7067e-04],\n", + " [-1.3593e-01, 5.2201e-02, 3.2895e-02, ..., 2.4580e-02,\n", + " 1.4590e-01, -4.6850e-02],\n", + " ...,\n", + " [ 1.0434e-03, 4.2251e-04, 6.5688e-04, ..., 1.2144e-03,\n", + " 2.1159e-04, 6.6838e-04],\n", + " [ 6.4997e-04, 4.4301e-04, 4.1550e-04, ..., 1.0420e-03,\n", + " 2.4114e-04, 1.5338e-04],\n", + " [-9.9337e-01, 5.4573e-01, -1.1371e-02, ..., -4.3175e-01,\n", + " -2.7850e-01, -4.4679e-01]], device='cuda:0')\n" + ] + } + ], + "source": [ + "print(model.decoder.output_layer.bias.grad)\n", + "print(model.decoder.output_layer.weight.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "40d00a54", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01, ..., -8.2428e-01,\n", + " -1.0265e+00, -9.6301e-01],\n", + " [-4.4642e-02, 2.3176e-01, -3.2539e-01, ..., -9.0159e-01,\n", + " -1.0325e+00, -7.5987e-01],\n", + " [ 5.0035e-01, 2.2691e-01, -7.3052e-01, ..., -1.0055e+00,\n", + " -8.7123e-01, -1.0306e+00],\n", + " ...,\n", + " [-4.0024e-01, -1.4325e-01, -5.7947e-01, ..., -1.0718e+00,\n", + " -1.2806e+00, -1.0518e+00],\n", + " [ 1.5755e-01, -1.8495e-03, -2.8703e-01, ..., -1.1090e+00,\n", + " -9.4519e-01, -7.2506e-01],\n", + " [-4.7520e-01, -1.3942e+00, -2.5754e-01, ..., -1.1365e+00,\n", + " -1.1943e+00, -1.2290e+00]],\n", + "\n", + " [[ 9.5454e-01, 3.6428e-01, -1.3891e+00, ..., -1.1637e+00,\n", + " -1.2845e+00, -1.2015e+00],\n", + " [-8.5735e-02, -1.0579e+00, -8.9173e-01, ..., -9.6441e-01,\n", + " -1.1255e+00, -1.2599e+00],\n", + " [ 4.7654e-01, 3.2887e-01, -5.9201e-01, ..., -1.1942e+00,\n", + " -1.1430e+00, -1.0242e+00],\n", + " ...,\n", + " [-4.7431e-01, -3.3559e-01, -7.2326e-01, ..., -1.4506e+00,\n", + " -1.3957e+00, -1.0464e+00],\n", + " [ 3.6113e-01, 1.0381e-01, -1.1599e+00, ..., -1.0439e+00,\n", + " -1.0221e+00, -1.0208e+00],\n", + " [-1.2717e+00, -2.1460e+00, -7.5677e-01, ..., -9.7822e-01,\n", + " -9.3785e-01, -1.0371e+00]],\n", + "\n", + " [[-1.5465e+00, -1.0152e+00, -8.8901e-01, ..., -4.8522e-01,\n", + " -7.5163e-01, -6.7765e-01],\n", + " [-7.6101e-01, -7.3352e-01, -9.1588e-01, ..., -2.4836e-01,\n", + " -5.8927e-01, -7.3723e-01],\n", + " [-2.4714e-02, 1.7016e-01, -4.2326e-01, ..., -3.3204e-01,\n", + " -7.6696e-01, -7.1652e-01],\n", + " ...,\n", + " [-1.7032e+00, -1.2591e+00, -1.1449e+00, ..., -1.1810e+00,\n", + " -1.1163e+00, -9.3108e-01],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 6.4983e-01, 2.6117e-01, -8.4197e-01, ..., -8.7213e-01,\n", + " -1.1073e+00, -1.3253e+00],\n", + " [ 3.5391e-01, -1.5846e-02, -4.0425e-01, ..., -9.9173e-01,\n", + " -1.0727e+00, -1.1924e+00],\n", + " [ 3.7704e-01, -6.2785e-02, -1.1468e-01, ..., -1.1021e+00,\n", + " -1.0952e+00, -1.1182e+00],\n", + " ...,\n", + " [-6.0434e+00, 
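The RuntimeError above is expected: `loss_ctc.backward()` earlier in the notebook already freed the buffers of the shared encoder graph, so a second backward that reaches the same graph from `loss_att` fails. (The `loss_att.grad`/`decoder_out.grad` prints would also be None with a warning, since those are non-leaf tensors.) For this kind of interactive inspection, pass `retain_graph=True` to the first backward; training code instead backpropagates the single combined loss once. A minimal, self-contained reproduction of the pattern:

```python
import torch

# Minimal reproduction: two losses that share part of one autograd graph
# (tanh(x) stands in for the shared encoder output).
x = torch.randn(3, requires_grad=True)
h = torch.tanh(x)            # shared intermediate; its backward needs saved tensors
loss_a = h.sum()
loss_b = (h ** 2).sum()

loss_a.backward(retain_graph=True)  # keep the shared buffers alive
loss_b.backward()                   # without retain_graph above, this raises the same RuntimeError
print(x.grad)
```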
-4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00]],\n", + "\n", + " [[ 4.4458e-02, -1.7547e-01, -6.7475e-01, ..., -4.9801e-01,\n", + " -5.6783e-01, -7.7852e-01],\n", + " [-1.3428e+00, -8.0343e-01, -9.0457e-01, ..., -6.5902e-01,\n", + " -7.2550e-01, -6.2796e-01],\n", + " [-7.6253e-01, -1.3071e-01, -1.3280e-01, ..., -5.6133e-01,\n", + " -6.0588e-01, -7.2115e-01],\n", + " ...,\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00]],\n", + "\n", + " [[-1.0798e+00, -1.0834e+00, -1.1797e+00, ..., -1.7757e-01,\n", + " -4.3747e-01, -4.0007e-02],\n", + " [ 9.2354e-01, 6.3771e-01, -5.2810e-01, ..., -1.2928e-01,\n", + " -2.0342e-01, 1.6656e-01],\n", + " [ 4.9337e-01, -9.1133e-03, -7.3302e-01, ..., 1.0074e-01,\n", + " -9.8115e-02, -9.2357e-03],\n", + " ...,\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00],\n", + " [-6.0434e+00, -4.9397e+00, -3.4235e+00, ..., -3.9949e+00,\n", + " -3.9869e+00, -3.6797e+00]]], device='cuda:0')\n" + ] + } + ], + "source": [ + "xs = model.encoder.global_cmvn(feat)\n", + "print(xs)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "505ca294", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[ True, True, True, ..., True, True, True]],\n", + "\n", + " [[ True, True, True, ..., True, True, True]],\n", + "\n", + " [[ True, True, True, ..., True, False, False]],\n", + "\n", + " ...,\n", + "\n", + " [[ True, True, True, ..., False, False, False]],\n", + "\n", + " [[ True, True, True, ..., False, False, False]],\n", + "\n", + " [[ True, True, True, ..., False, False, False]]], device='cuda:0')\n" + ] + } + ], + "source": [ + "from wenet.utils.mask import make_pad_mask\n", + "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", + "print(masks)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "aa03c2b9", + "metadata": {}, + "outputs": [], + "source": [ + "xs, pos_emb, masks = model.encoder.embed(xs, masks)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "ebc0ea12", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[-0.5482, 2.2866, -1.0750, ..., 1.4504, 0.2895, -0.6945],\n", + " [-0.8013, 1.7688, -1.6639, ..., 1.8332, 0.6791, -0.2000],\n", + " [-1.7112, 2.7057, -1.3363, ..., 1.2336, 0.1870, -0.5735],\n", + " ...,\n", + " [-0.9697, 2.3129, -0.8752, ..., 0.8584, 0.4853, -0.4177],\n", + " [-1.3609, 2.1779, -1.7813, ..., 2.0928, 0.2528, -0.3650],\n", + " [-1.6967, 2.3544, -1.7417, ..., 1.3670, 0.5951, -0.7415]],\n", + "\n", + " [[-1.9828, 2.3178, -0.9079, ..., 0.4117, 0.5006, 0.0872],\n", + " [-0.7640, 1.3558, -1.3613, ..., 0.7317, 0.6784, 0.1685],\n", + " [-0.9504, 1.6038, -1.3030, ..., 0.5754, 0.2677, 0.3343],\n", + " ...,\n", + " [-1.4757, 2.5317, -1.2321, ..., 1.2997, 0.5019, -0.1034],\n", + " [-1.1731, 2.3172, -1.2542, ..., 1.7391, 0.2171, -0.4445],\n", + " [-1.2700, 3.2229, 
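`global_cmvn` normalizes every frame with precomputed global statistics, roughly `xs = (feat - mean) * istd` (the assumed behaviour of wenet's `GlobalCMVN`; the mean vector is the one printed during the state-dict dump above). That is also why the padded tail frames all collapse to the same constant row (-6.0434, -4.9397, ...): they are zero frames shifted and scaled by the same statistics. A toy illustration under that assumption:

```python
import torch

# Toy global CMVN: subtract a global mean and scale by a global inverse std,
# per feature dimension (assumed behaviour, not the actual model statistics).
mean = torch.tensor([10.0, 12.0])
istd = torch.tensor([0.5, 0.25])

feat = torch.tensor([[11.0, 16.0],
                     [ 0.0,  0.0]])   # the second row mimics a zero padding frame

print((feat - mean) * istd)
# tensor([[ 0.5000,  1.0000],
#         [-5.0000, -3.0000]])   # padded frames all map to the same constant vector
```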
-0.8872, ..., 1.6461, 0.0973, -0.7679]],\n", + "\n", + " [[-0.5873, 1.4291, -1.3950, ..., 0.2102, 0.1027, 0.0918],\n", + " [ 0.1743, 1.7834, -1.6422, ..., 0.8113, 0.3137, 0.5634],\n", + " [-0.3492, 1.8310, -1.0685, ..., 0.6924, 0.1378, 0.4594],\n", + " ...,\n", + " [-1.0869, 2.3002, -1.2638, ..., 1.7998, 0.5134, -0.5223],\n", + " [-1.2614, 2.7240, -1.3734, ..., 1.4445, 0.5742, -0.3320],\n", + " [-2.2068, 4.3462, -3.8289, ..., 2.1426, 1.2034, -1.3795]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.3914, 1.8553, -0.5747, ..., 1.0062, 0.4632, -1.0452],\n", + " [-0.8605, 2.0172, -1.4437, ..., 1.4526, 0.1657, 0.5923],\n", + " [-0.7307, 2.2841, -1.0699, ..., 1.5825, -0.0980, 0.5503],\n", + " ...,\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", + "\n", + " [[-0.1619, 0.6255, -1.1323, ..., 0.0724, -0.2204, 0.4636],\n", + " [-0.0831, 0.5750, -1.0930, ..., 0.9110, -0.0650, 0.7299],\n", + " [-0.2820, 0.0801, -0.9418, ..., 0.3379, -0.1166, 0.4451],\n", + " ...,\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]],\n", + "\n", + " [[-0.5458, -0.6909, -1.3597, ..., -0.7818, 0.6875, 0.9843],\n", + " [ 0.0421, -1.1062, -1.4389, ..., -0.0239, 0.9115, 0.5287],\n", + " [-0.2909, -0.1886, -1.5487, ..., -0.1392, 0.0580, 0.3066],\n", + " ...,\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270],\n", + " [-5.0821, 8.5920, -4.2137, ..., 6.2693, 0.0539, -2.9270]]],\n", + " device='cuda:0', grad_fn=)\n", + "tensor([[[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", + " 0.0000e+00, 1.0000e+00],\n", + " [ 8.4147e-01, 5.4030e-01, 8.0196e-01, ..., 1.0000e+00,\n", + " 1.0746e-04, 1.0000e+00],\n", + " [ 9.0930e-01, -4.1615e-01, 9.5814e-01, ..., 1.0000e+00,\n", + " 2.1492e-04, 1.0000e+00],\n", + " ...,\n", + " [-7.6825e-01, -6.4014e-01, 6.3280e-01, ..., 9.9998e-01,\n", + " 5.1581e-03, 9.9999e-01],\n", + " [-9.5375e-01, 3.0059e-01, 9.9899e-01, ..., 9.9998e-01,\n", + " 5.2656e-03, 9.9999e-01],\n", + " [-2.6237e-01, 9.6497e-01, 5.6075e-01, ..., 9.9998e-01,\n", + " 5.3730e-03, 9.9999e-01]]], device='cuda:0')\n", + "tensor([[[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]],\n", + "\n", + " [[ True, True, True, True, True, True, True, 
True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, False, False, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, False, False, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, False, False, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, 
True, True, True, True, True, True, False, False, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, False, False, False, False, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, False, False, False, False, False, False, False, False,\n", + " False]],\n", + "\n", + " [[ True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, False, False, False, False, False, False, False, False, False,\n", + " False]]], device='cuda:0')\n", + "torch.Size([16, 1, 51])\n" + ] + } + ], + "source": [ + "print(xs)\n", + "print(pos_emb)\n", + "print(masks)\n", + "print(masks.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "4289461b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", + " -0.6945408 ]\n", + " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", + " -0.1999542 ]\n", + " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", + " -0.5735198 ]\n", + " ...\n", + " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", + " -0.41773027]\n", + " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", + " -0.36496443]\n", + " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", + " -0.74147725]]\n", + "\n", + " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", + " 0.08721463]\n", + " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", + " 0.16851945]\n", + " [-0.95044655 1.6037656 -1.3029968 ... 0.57544005 0.26769355\n", + " 0.33433008]\n", + " ...\n", + " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", + " -0.10343577]\n", + " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", + " -0.44447583]\n", + " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", + " -0.7678688 ]]\n", + "\n", + " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", + " 0.09179455]\n", + " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", + " 0.56344515]\n", + " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", + " 0.45937473]\n", + " ...\n", + " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", + " -0.52227837]\n", + " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", + " -0.33201432]\n", + " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", + " -1.3795122 ]]\n", + "\n", + " ...\n", + "\n", + " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", + " -1.045236 ]\n", + " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", + " 0.5923172 ]\n", + " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", + " 0.55030036]\n", + " ...\n", + " [-5.08209 8.592033 -4.2136674 ... 
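The masks printed above come from `make_pad_mask(feat_len)`, which marks padded frames, and the leading `~` flips it so that True means "valid frame"; `model.encoder.embed` then subsamples in time (by a factor of 4 for the Conv2dSubsampling4 front end usually used in these conformer configs), which is why the mask length shrinks to 51. A simplified stand-in for the length-based mask, assuming it behaves like the function imported from `wenet.utils.mask`:

```python
import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    """Return a (B, T_max) bool mask that is True at padded positions.

    Simplified stand-in for wenet.utils.mask.make_pad_mask; the notebook
    uses the real implementation.
    """
    batch_size = lengths.size(0)
    max_len = int(lengths.max().item())
    seq_range = torch.arange(max_len, device=lengths.device)             # (T_max,)
    return seq_range.unsqueeze(0).expand(batch_size, max_len) >= lengths.unsqueeze(1)

lengths = torch.tensor([5, 3, 4])
print(~make_pad_mask_sketch(lengths).unsqueeze(1))   # (B, 1, T_max), True = valid frame
# tensor([[[ True,  True,  True,  True,  True]],
#         [[ True,  True,  True, False, False]],
#         [[ True,  True,  True,  True, False]]])
```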
6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]]\n", + "\n", + " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", + " 0.46362036]\n", + " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", + " 0.72986233]\n", + " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", + " 0.44514441]\n", + " ...\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]]\n", + "\n", + " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", + " 0.9842716 ]\n", + " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", + " 0.52870303]\n", + " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", + " 0.30663735]\n", + " ...\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]]]\n" + ] + } + ], + "source": [ + "xs = model.encoder.global_cmvn(feat)\n", + "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", + "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n", + "print(xs.cpu().detach().numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "67e10d73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 2.0908e-03],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 1.1943e-02, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 4.6105e-02, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 9.6723e-03,\n", + " 4.6135e-02, 0.0000e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[2.2816e-01, 2.4615e-01, 2.5304e-01, 
..., 2.0402e-01,\n", + " 2.3248e-01, 3.1191e-01],\n", + " [1.3587e-01, 2.8877e-01, 2.7991e-01, ..., 1.9210e-01,\n", + " 2.0346e-01, 1.9934e-01],\n", + " [2.5739e-01, 3.9348e-01, 2.7877e-01, ..., 2.7483e-01,\n", + " 1.9302e-01, 2.3810e-01],\n", + " ...,\n", + " [1.1939e-01, 2.8473e-01, 3.3082e-01, ..., 2.3838e-01,\n", + " 2.2104e-01, 2.3906e-01],\n", + " [1.7388e-01, 2.0402e-01, 4.0263e-01, ..., 2.4782e-01,\n", + " 2.6742e-01, 1.5427e-01],\n", + " [0.0000e+00, 2.9081e-01, 2.7726e-01, ..., 1.7540e-01,\n", + " 1.8479e-01, 2.2483e-01]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[3.5447e-01, 3.8861e-01, 3.9724e-01, ..., 3.8680e-01,\n", + " 3.3568e-01, 3.4552e-01],\n", + " [4.1739e-01, 5.1039e-01, 4.1730e-01, ..., 3.3993e-01,\n", + " 3.7082e-01, 3.5110e-01],\n", + " [3.6117e-01, 4.0745e-01, 4.8491e-01, ..., 3.4849e-01,\n", + " 3.2321e-01, 3.5189e-01],\n", + " ...,\n", + " [2.3144e-01, 3.8021e-01, 5.1526e-01, ..., 3.6499e-01,\n", + " 3.7412e-01, 3.9986e-01],\n", + " [3.4679e-01, 4.0238e-01, 5.0077e-01, ..., 3.6185e-01,\n", + " 3.1597e-01, 3.6335e-01],\n", + " [3.6498e-01, 3.7943e-01, 5.1719e-01, ..., 3.1798e-01,\n", + " 3.3657e-01, 3.4130e-01]]],\n", + "\n", + "\n", + " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[1.4560e-02, 9.4475e-02, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [1.5002e-02, 2.9632e-02, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [3.2952e-02, 0.0000e+00, 0.0000e+00, ..., 4.5850e-02,\n", + " 2.0439e-02, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 4.4258e-02],\n", + " [0.0000e+00, 0.0000e+00, 2.5565e-02, ..., 0.0000e+00,\n", + " 9.0044e-03, 4.9084e-02]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [1.1141e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " ...,\n", + "\n", + " 
[[3.3697e-01, 3.8527e-01, 3.2900e-01, ..., 2.8704e-01,\n", + " 2.3351e-01, 1.9004e-01],\n", + " [1.3575e-01, 3.5783e-01, 3.3573e-01, ..., 2.2082e-01,\n", + " 1.5855e-01, 1.3587e-01],\n", + " [2.1929e-01, 2.8900e-01, 2.8255e-01, ..., 2.0603e-01,\n", + " 2.3927e-01, 2.1909e-01],\n", + " ...,\n", + " [2.3292e-01, 3.9097e-01, 3.6399e-01, ..., 2.0598e-01,\n", + " 2.5374e-01, 2.3137e-01],\n", + " [1.8739e-01, 3.0794e-01, 3.0297e-01, ..., 2.7251e-01,\n", + " 2.5192e-01, 2.0837e-01],\n", + " [2.2454e-01, 4.1402e-01, 5.4083e-01, ..., 3.1875e-01,\n", + " 2.5080e-01, 2.5939e-01]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[2.6457e-01, 4.9519e-01, 5.6702e-01, ..., 3.0955e-01,\n", + " 3.5292e-01, 3.2669e-01],\n", + " [2.1577e-01, 5.1833e-01, 4.9183e-01, ..., 3.6043e-01,\n", + " 3.8524e-01, 3.6155e-01],\n", + " [2.0068e-01, 4.2784e-01, 5.2818e-01, ..., 3.1871e-01,\n", + " 3.2452e-01, 3.1036e-01],\n", + " ...,\n", + " [4.9855e-01, 5.1001e-01, 5.2279e-01, ..., 3.6450e-01,\n", + " 3.4338e-01, 3.3603e-01],\n", + " [4.1233e-01, 5.5518e-01, 5.2828e-01, ..., 4.0676e-01,\n", + " 3.3873e-01, 3.6724e-01],\n", + " [4.0820e-01, 4.6187e-01, 4.7338e-01, ..., 3.8691e-01,\n", + " 3.6039e-01, 3.8022e-01]]],\n", + "\n", + "\n", + " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[0.0000e+00, 5.7852e-03, 0.0000e+00, ..., 7.4838e-03,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 3.0351e-02,\n", + " 0.0000e+00, 2.6720e-04],\n", + " [9.4807e-04, 0.0000e+00, 0.0000e+00, ..., 7.9551e-03,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [2.0326e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 1.0801e-02, 0.0000e+00],\n", + " [1.8470e-01, 0.0000e+00, 0.0000e+00, ..., 5.0584e-02,\n", + " 9.4758e-02, 5.9146e-02]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 
0.0000e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[3.8708e-01, 2.8022e-01, 3.5893e-01, ..., 1.6595e-01,\n", + " 1.6031e-01, 2.1136e-01],\n", + " [1.5595e-01, 3.0544e-01, 2.4666e-01, ..., 2.2675e-01,\n", + " 2.5765e-01, 1.9682e-01],\n", + " [2.9518e-01, 4.1210e-01, 2.0063e-01, ..., 1.7595e-01,\n", + " 2.2537e-01, 2.2214e-01],\n", + " ...,\n", + " [2.4745e-01, 2.6259e-01, 3.8654e-01, ..., 2.3620e-01,\n", + " 2.3157e-01, 1.8514e-01],\n", + " [2.5715e-01, 2.9593e-01, 4.7745e-01, ..., 2.3546e-01,\n", + " 2.5073e-01, 2.0976e-01],\n", + " [1.2015e+00, 8.4644e-01, 7.3386e-01, ..., 1.0252e+00,\n", + " 9.5310e-01, 1.0013e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[4.5013e-01, 4.7484e-01, 4.0540e-01, ..., 1.9346e-01,\n", + " 1.7826e-01, 1.4777e-01],\n", + " [4.7546e-01, 4.8187e-01, 3.6760e-01, ..., 2.7809e-01,\n", + " 3.2997e-01, 3.2337e-01],\n", + " [4.6160e-01, 4.0050e-01, 3.9061e-01, ..., 3.6613e-01,\n", + " 3.5243e-01, 2.9739e-01],\n", + " ...,\n", + " [5.5148e-01, 5.1018e-01, 4.0132e-01, ..., 3.8948e-01,\n", + " 3.5737e-01, 3.3088e-01],\n", + " [4.1973e-01, 4.5475e-01, 4.5320e-01, ..., 3.8343e-01,\n", + " 4.0126e-01, 3.6181e-01],\n", + " [3.4280e-01, 3.1606e-01, 4.4701e-01, ..., 2.1665e-01,\n", + " 2.3985e-01, 2.3903e-01]]],\n", + "\n", + "\n", + " ...,\n", + "\n", + "\n", + " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [4.1783e-02, 0.0000e+00, 1.5805e-02, ..., 0.0000e+00,\n", + " 2.2508e-02, 0.0000e+00],\n", + " [4.3234e-02, 7.7864e-02, 0.0000e+00, ..., 1.6347e-02,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[3.2092e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [1.3563e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " 
[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[0.0000e+00, 2.5187e-01, 2.4979e-01, ..., 2.4775e-01,\n", + " 2.2354e-01, 1.9149e-01],\n", + " [1.6541e-01, 1.9586e-01, 1.9813e-01, ..., 2.7344e-01,\n", + " 2.0928e-01, 2.6150e-01],\n", + " [1.0495e-01, 6.3299e-02, 3.3844e-01, ..., 2.5138e-01,\n", + " 1.2470e-01, 2.3927e-01],\n", + " ...,\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00],\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00],\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[1.1428e-01, 4.5667e-01, 4.6821e-01, ..., 3.2058e-01,\n", + " 3.3579e-01, 3.9013e-01],\n", + " [1.0441e-01, 4.5739e-01, 4.6107e-01, ..., 3.8468e-01,\n", + " 3.8291e-01, 3.6686e-01],\n", + " [1.9868e-01, 3.5520e-01, 4.4313e-01, ..., 4.0679e-01,\n", + " 3.8068e-01, 3.0646e-01],\n", + " ...,\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00],\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00],\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00]]],\n", + "\n", + "\n", + " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[2.4654e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 3.3902e-02],\n", + " [0.0000e+00, 0.0000e+00, 1.8307e-02, ..., 5.1669e-02,\n", + " 9.4838e-03, 7.4535e-02],\n", + " [9.9215e-02, 0.0000e+00, 1.5872e-02, ..., 1.6203e-02,\n", + " 5.1401e-02, 1.9239e-03],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + 
" 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[4.0034e-01, 2.5306e-01, 2.0218e-01, ..., 9.8162e-02,\n", + " 7.0643e-02, 4.9741e-02],\n", + " [1.2568e-01, 2.1031e-01, 1.1182e-01, ..., 4.2781e-02,\n", + " 1.1969e-01, 1.2005e-01],\n", + " [2.8787e-01, 2.4031e-01, 2.2566e-01, ..., 0.0000e+00,\n", + " 6.4181e-02, 5.8730e-02],\n", + " ...,\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00],\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00],\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[3.8405e-01, 3.0990e-01, 3.7156e-01, ..., 1.8125e-01,\n", + " 1.5051e-01, 1.9620e-01],\n", + " [4.7286e-01, 4.0529e-01, 3.9718e-01, ..., 2.4710e-01,\n", + " 4.5657e-02, 1.1501e-01],\n", + " [3.2621e-01, 3.0073e-01, 3.0477e-01, ..., 2.3529e-01,\n", + " 2.1357e-01, 1.6986e-01],\n", + " ...,\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00],\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00],\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00]]],\n", + "\n", + "\n", + " [[[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[3.3438e-02, 1.2378e-03, 5.2972e-02, ..., 7.2712e-02,\n", + " 8.6563e-02, 1.4494e-01],\n", + " [1.1043e-01, 6.1431e-02, 6.3630e-02, ..., 8.1278e-02,\n", + " 6.2590e-02, 8.3154e-02],\n", + " [1.7677e-02, 2.0111e-03, 7.8750e-02, ..., 6.9633e-02,\n", + " 8.9799e-02, 5.3263e-02],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[1.0034e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [1.5627e-01, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [5.1447e-02, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 4.3641e-03],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 
0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[2.5142e-01, 4.5964e-01, 3.7346e-01, ..., 4.7631e-02,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [1.9760e-01, 2.6627e-01, 1.1191e-01, ..., 3.0450e-02,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [1.6341e-01, 3.2938e-01, 2.5690e-01, ..., 5.5694e-02,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00],\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00],\n", + " [1.1257e+00, 8.7341e-01, 7.8169e-01, ..., 1.0458e+00,\n", + " 1.0094e+00, 1.0221e+00]],\n", + "\n", + " [[0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 2.2189e-02, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 2.8490e-02],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " ...,\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00],\n", + " [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,\n", + " 0.0000e+00, 0.0000e+00]],\n", + "\n", + " [[2.5810e-01, 6.3017e-01, 3.7038e-01, ..., 1.8704e-01,\n", + " 8.2694e-02, 9.9127e-02],\n", + " [1.7293e-01, 5.0679e-01, 4.0739e-01, ..., 1.6006e-01,\n", + " 1.1725e-01, 9.9405e-02],\n", + " [2.4175e-01, 4.1616e-01, 4.1257e-01, ..., 1.3520e-01,\n", + " 7.9126e-02, 1.2846e-01],\n", + " ...,\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00],\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00],\n", + " [1.4488e+00, 1.0212e+00, 9.4473e-01, ..., 1.2363e+00,\n", + " 1.2189e+00, 1.2380e+00]]]], device='cuda:0',\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "xs = model.encoder.global_cmvn(feat)\n", + "masks = ~make_pad_mask(feat_len).unsqueeze(1) # (B, 1, L)\n", + "\n", + "x = xs.unsqueeze(1)\n", + "x = model.encoder.embed.conv(x)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "9a9478ad", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[-0.03426375 0.14291267 -0.06718873 ... 0.09064753 0.01809387\n", + " -0.0434088 ]\n", + " [-0.05007839 0.11054724 -0.10399298 ... 0.11457238 0.04244684\n", + " -0.01249714]\n", + " [-0.10695291 0.16910909 -0.08352133 ... 0.07710276 0.01168563\n", + " -0.03584499]\n", + " ...\n", + " [-0.06060536 0.14455931 -0.05470302 ... 0.05364908 0.03033342\n", + " -0.02610814]\n", + " [-0.08505894 0.13611752 -0.11132983 ... 0.13079923 0.01580139\n", + " -0.02281028]\n", + " [-0.10604677 0.14714901 -0.10885533 ... 0.08543444 0.03719445\n", + " -0.04634233]]\n", + "\n", + " [[-0.12392755 0.14486063 -0.05674079 ... 0.02573164 0.03128851\n", + " 0.00545091]\n", + " [-0.04775286 0.08473608 -0.08507854 ... 0.04573154 0.04240163\n", + " 0.01053247]\n", + " [-0.05940291 0.10023535 -0.0814373 ... 0.035965 0.01673085\n", + " 0.02089563]\n", + " ...\n", + " [-0.09222981 0.15823206 -0.07700447 ... 0.08122957 0.03136991\n", + " -0.00646474]\n", + " [-0.07331756 0.14482647 -0.07838815 ... 0.1086944 0.01356864\n", + " -0.02777974]\n", + " [-0.07937264 0.20143102 -0.05544947 ... 
0.10287814 0.00608235\n", + " -0.0479918 ]]\n", + "\n", + " [[-0.03670349 0.0893159 -0.08718812 ... 0.0131405 0.00642052\n", + " 0.00573716]\n", + " [ 0.01089254 0.11146393 -0.10263617 ... 0.05070438 0.01960694\n", + " 0.03521532]\n", + " [-0.0218228 0.11443964 -0.06678198 ... 0.04327708 0.00861394\n", + " 0.02871092]\n", + " ...\n", + " [-0.06792898 0.14376275 -0.07899005 ... 0.11248926 0.03208683\n", + " -0.0326424 ]\n", + " [-0.07884051 0.17024788 -0.08583611 ... 0.09028331 0.03588808\n", + " -0.0207509 ]\n", + " [-0.13792302 0.27163863 -0.23930418 ... 0.13391261 0.0752104\n", + " -0.08621951]]\n", + "\n", + " ...\n", + "\n", + " [[-0.02446348 0.11595841 -0.03591986 ... 0.0628897 0.02895011\n", + " -0.06532725]\n", + " [-0.05378424 0.1260737 -0.09023033 ... 0.09078894 0.01035743\n", + " 0.03701983]\n", + " [-0.04566649 0.14275314 -0.0668687 ... 0.09890588 -0.00612222\n", + " 0.03439377]\n", + " ...\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]]\n", + "\n", + " [[-0.01012144 0.03909408 -0.07077143 ... 0.00452683 -0.01377654\n", + " 0.02897627]\n", + " [-0.00519154 0.03594019 -0.06831125 ... 0.05693541 -0.00406374\n", + " 0.0456164 ]\n", + " [-0.01762631 0.00500899 -0.05886075 ... 0.02112178 -0.00729015\n", + " 0.02782153]\n", + " ...\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]]\n", + "\n", + " [[-0.03411558 -0.04318277 -0.08497842 ... -0.04886402 0.04296734\n", + " 0.06151697]\n", + " [ 0.00263296 -0.06913657 -0.08993219 ... -0.00149064 0.05696633\n", + " 0.03304394]\n", + " [-0.01818341 -0.0117864 -0.09679577 ... -0.00870231 0.00362198\n", + " 0.01916483]\n", + " ...\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]\n", + " [-0.31763062 0.5370021 -0.2633542 ... 0.39182857 0.00337184\n", + " -0.18293698]]]\n", + "torch.Size([16, 51, 256])\n" + ] + } + ], + "source": [ + "b, c, t, f = x.size()\n", + "x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))\n", + "print(x.cpu().detach().numpy())\n", + "print(x.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "fd69003f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[-0.54822 2.2866027 -1.0750197 ... 1.4503604 0.28950194\n", + " -0.6945408 ]\n", + " [-0.8012542 1.7687558 -1.6638877 ... 1.833158 0.6791494\n", + " -0.1999542 ]\n", + " [-1.7112465 2.7057455 -1.3363413 ... 1.2336441 0.18697014\n", + " -0.5735198 ]\n", + " ...\n", + " [-0.96968573 2.312949 -0.87524825 ... 0.85838526 0.4853347\n", + " -0.41773027]\n", + " [-1.3609431 2.1778803 -1.7812773 ... 2.0927877 0.25282228\n", + " -0.36496443]\n", + " [-1.6967483 2.3543842 -1.7416853 ... 1.366951 0.59511113\n", + " -0.74147725]]\n", + "\n", + " [[-1.9828408 2.31777 -0.9078527 ... 0.41170627 0.5006162\n", + " 0.08721463]\n", + " [-0.76404583 1.3557773 -1.3612567 ... 0.7317046 0.678426\n", + " 0.16851945]\n", + " [-0.95044655 1.6037656 -1.3029968 ... 
0.57544005 0.26769355\n", + " 0.33433008]\n", + " ...\n", + " [-1.475677 2.531713 -1.2320715 ... 1.2996731 0.50191855\n", + " -0.10343577]\n", + " [-1.1730809 2.3172235 -1.2542105 ... 1.7391105 0.21709818\n", + " -0.44447583]\n", + " [-1.2699623 3.2228963 -0.8871915 ... 1.6460502 0.09731755\n", + " -0.7678688 ]]\n", + "\n", + " [[-0.5872559 1.4290544 -1.3950099 ... 0.21024795 0.10272825\n", + " 0.09179455]\n", + " [ 0.1742807 1.783423 -1.6421788 ... 0.8112701 0.31371105\n", + " 0.56344515]\n", + " [-0.34916472 1.8310343 -1.0685117 ... 0.69243336 0.13782299\n", + " 0.45937473]\n", + " ...\n", + " [-1.0868638 2.300204 -1.2638408 ... 1.7998282 0.5133892\n", + " -0.52227837]\n", + " [-1.2614481 2.7239661 -1.3733778 ... 1.444533 0.57420933\n", + " -0.33201432]\n", + " [-2.2067683 4.346218 -3.828867 ... 2.1426017 1.2033664\n", + " -1.3795122 ]]\n", + "\n", + " ...\n", + "\n", + " [[-0.39141566 1.8553346 -0.5747178 ... 1.0062351 0.46320182\n", + " -1.045236 ]\n", + " [-0.86054784 2.0171793 -1.4436853 ... 1.452623 0.16571884\n", + " 0.5923172 ]\n", + " [-0.73066384 2.2840502 -1.0698992 ... 1.5824941 -0.0979555\n", + " 0.55030036]\n", + " ...\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]]\n", + "\n", + " [[-0.16194311 0.6255052 -1.1323429 ... 0.07242929 -0.22042468\n", + " 0.46362036]\n", + " [-0.08306468 0.575043 -1.09298 ... 0.9109665 -0.06501988\n", + " 0.72986233]\n", + " [-0.28202093 0.08014385 -0.9417719 ... 0.3379485 -0.11664233\n", + " 0.44514441]\n", + " ...\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]]\n", + "\n", + " [[-0.5458492 -0.69092435 -1.3596548 ... -0.78182435 0.68747747\n", + " 0.9842716 ]\n", + " [ 0.04212743 -1.1061852 -1.438915 ... -0.02385022 0.91146135\n", + " 0.52870303]\n", + " [-0.2909345 -0.18858244 -1.5487324 ... -0.13923697 0.05795169\n", + " 0.30663735]\n", + " ...\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]\n", + " [-5.08209 8.592033 -4.2136674 ... 6.269257 0.05394945\n", + " -2.9269917 ]]]\n" + ] + } + ], + "source": [ + "x, pos_emb = model.encoder.embed.pos_enc(x, 0)\n", + "print(x.cpu().detach().numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "8ed88489", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.float32\n", + "[[[ 0.0000000e+00 1.0000000e+00 0.0000000e+00 ... 1.0000000e+00\n", + " 0.0000000e+00 1.0000000e+00]\n", + " [ 8.4147096e-01 5.4030234e-01 8.0196178e-01 ... 1.0000000e+00\n", + " 1.0746076e-04 1.0000000e+00]\n", + " [ 9.0929741e-01 -4.1614684e-01 9.5814437e-01 ... 1.0000000e+00\n", + " 2.1492151e-04 1.0000000e+00]\n", + " ...\n", + " [-7.6825464e-01 -6.4014435e-01 6.3279724e-01 ... 9.9998462e-01\n", + " 5.1580933e-03 9.9998671e-01]\n", + " [-9.5375264e-01 3.0059254e-01 9.9899054e-01 ... 9.9998397e-01\n", + " 5.2655530e-03 9.9998611e-01]\n", + " [-2.6237485e-01 9.6496606e-01 5.6074661e-01 ... 
9.9998331e-01\n",
+      "   5.3730118e-03  9.9998558e-01]]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(pos_emb.dtype)\n",
+    "print(pos_emb.cpu().detach().numpy())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "id": "5e277881",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([16, 51, 256])\n"
+     ]
+    }
+   ],
+   "source": [
+    "def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,\n",
+    "                            use_dynamic_chunk: bool,\n",
+    "                            use_dynamic_left_chunk: bool,\n",
+    "                            decoding_chunk_size: int, static_chunk_size: int,\n",
+    "                            num_decoding_left_chunks: int):\n",
+    "    \"\"\" Apply optional mask for encoder.\n",
+    "    Args:\n",
+    "        xs (torch.Tensor): padded input, (B, L, D), L for max length\n",
+    "        masks (torch.Tensor): mask for xs, (B, 1, L)\n",
+    "        use_dynamic_chunk (bool): whether to use dynamic chunk or not\n",
+    "        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for\n",
+    "            training.\n",
+    "        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's\n",
+    "            0: default for training, use random dynamic chunk.\n",
+    "            <0: for decoding, use full chunk.\n",
+    "            >0: for decoding, use fixed chunk size as set.\n",
+    "        static_chunk_size (int): chunk size for static chunk training/decoding\n",
+    "            if it's greater than 0, if use_dynamic_chunk is true,\n",
+    "            this parameter will be ignored\n",
+    "        num_decoding_left_chunks: number of left chunks, this is for decoding,\n",
+    "        the chunk size is decoding_chunk_size.\n",
+    "            >=0: use num_decoding_left_chunks\n",
+    "            <0: use all left chunks\n",
+    "    Returns:\n",
+    "        torch.Tensor: chunk mask of the input xs.\n",
+    "    \"\"\"\n",
+    "    # Whether to use chunk mask or not\n",
+    "    if use_dynamic_chunk:\n",
+    "        max_len = xs.size(1)\n",
+    "        if decoding_chunk_size < 0:\n",
+    "            chunk_size = max_len\n",
+    "            num_left_chunks = -1\n",
+    "        elif decoding_chunk_size > 0:\n",
+    "            chunk_size = decoding_chunk_size\n",
+    "            num_left_chunks = num_decoding_left_chunks\n",
+    "        else:\n",
+    "            # chunk size is either [1, 25] or full context(max_len).\n",
+    "            # Since we use 4 times subsampling and allow up to 1s(100 frames)\n",
+    "            # delay, the maximum frame is 100 / 4 = 25.\n",
+    "            chunk_size = torch.randint(1, max_len, (1, )).item()\n",
+    "            num_left_chunks = -1\n",
+    "            if chunk_size > max_len // 2:\n",
+    "                chunk_size = max_len\n",
+    "            else:\n",
+    "                chunk_size = chunk_size % 25 + 1\n",
+    "                if use_dynamic_left_chunk:\n",
+    "                    max_left_chunks = (max_len - 1) // chunk_size\n",
+    "                    num_left_chunks = torch.randint(0, max_left_chunks,\n",
+    "                                                    (1, )).item()\n",
+    "        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,\n",
+    "                                            num_left_chunks,\n",
+    "                                            xs.device)  # (L, L)\n",
+    "        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)\n",
+    "        chunk_masks = masks & chunk_masks  # (B, L, L)\n",
+    "    elif static_chunk_size > 0:\n",
+    "        num_left_chunks = num_decoding_left_chunks\n",
+    "        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,\n",
+    "                                            num_left_chunks,\n",
+    "                                            xs.device)  # (L, L)\n",
+    "        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)\n",
+    "        chunk_masks = masks & chunk_masks  # (B, L, L)\n",
+    "    else:\n",
+    "        chunk_masks = masks\n",
+    "    return chunk_masks\n",
+    "\n",
+    "from wenet.utils.mask import make_pad_mask\n",
+    "\n",
+    "\n",
+    "masks = ~make_pad_mask(feat_len).unsqueeze(1)\n",
+    "xs = model.encoder.global_cmvn(feat)\n",
+    "xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)\n",
+    "\n",
+    "mask_pad = masks\n",
+    "decoding_chunk_size=0\n",
+    "num_decoding_left_chunks=-1\n",
+    "use_dynamic_left_chunk=False\n",
+    "use_dynamic_chunk=False\n",
+    "static_chunk_size=-1\n",
+    "chunk_masks = add_optional_chunk_mask(\n",
+    "    xs, \n",
+    "    masks, \n",
+    "    use_dynamic_chunk,\n",
+    "    use_dynamic_left_chunk,\n",
+    "    decoding_chunk_size, \n",
+    "    static_chunk_size,\n",
+    "    num_decoding_left_chunks)\n",
+    "\n",
+    "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', \n",
+    "         embed_out=xs.cpu().detach().numpy(), \n",
+    "         pos_emb=pos_emb.cpu().detach().numpy(),\n",
+    "        chunk_masks=chunk_masks.cpu().detach().numpy(),\n",
+    "        mask_pad=mask_pad.cpu().detach().numpy())\n",
+    "\n",
+    "model.eval()\n",
+    "# print(chunk_masks)\n",
+    "print(xs.shape)\n",
+    "for layer in model.encoder.encoders:\n",
+    "    #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
+    "    #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())\n",
+    "    \n",
+    "    x = xs\n",
+    "    residual = x\n",
+    "    x_norm = layer.norm_ff_macaron(x)\n",
+    "    !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz\n",
+    "    np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', \n",
+    "             norm_ff=x_norm.cpu().detach().numpy(),\n",
+    "            xs=xs.cpu().detach().numpy())\n",
+    "    #print(x.cpu().detach().numpy())\n",
+    "    for p in layer.norm_ff_macaron.parameters():\n",
+    "        #print(p, p.sum())\n",
+    "        pass\n",
+    "    \n",
+    "    x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)\n",
+    "    \n",
+    "    ps = []\n",
+    "    for n, p in layer.feed_forward_macaron.state_dict().items():\n",
+    "        #print(n, p.cpu().data.numpy())\n",
+    "        ps.append(p.cpu().data.numpy())\n",
+    "        pass\n",
+    "\n",
+    "    ff_l_x = layer.feed_forward_macaron.w_1(x_norm)\n",
+    "    ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)\n",
+    "    ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)\n",
+    "    np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', \n",
+    "             norm_ff=x_norm.cpu().detach().numpy(),\n",
+    "             ff_out=x.cpu().detach().numpy(),\n",
+    "             ff_l_x = ff_l_x.cpu().detach().numpy(),\n",
+    "             ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),\n",
+    "             ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),\n",
+    "             ps=ps,\n",
+    "            )\n",
+    "    \n",
+    "    \n",
+    "    residual = x\n",
+    "    x = layer.norm_mha(x)\n",
+    "    x_q = x\n",
+    "    \n",
+    "    x_att = layer.self_attn(x_q, x, x, pos_emb, masks)\n",
+    "    np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', \n",
+    "             x_q=x_q.cpu().detach().numpy(),\n",
+    "             x=x.cpu().detach().numpy(),\n",
+    "             pos_emb = pos_emb.cpu().detach().numpy(),\n",
+    "             mask=masks.cpu().detach().numpy(),\n",
+    "             x_att=x_att.cpu().detach().numpy(),\n",
+    "            )\n",
+    "    \n",
+    "    break\n",
+    "#print(xs.cpu().detach().numpy())\n",
+    "\n",
+    "\n",
+    "i = 0\n",
+    "for layer in model.encoder.encoders:\n",
+    "    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)\n",
+    "    i += 1\n",
+    "    if i == 2:\n",
+    "        np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())\n",
+    "    \n",
+    "np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c43fd4f1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out, mask = model.encoder(feat, feat_len)\n",
+    "#print(out.cpu().detach().numpy())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0e73db22",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8f506114",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}